活性化チェックポイントとメモリ最適化
GPUメモリ不足で層を積めない、そんな壁を活性化チェックポイントが崩す。逆伝播用の活性を捨てて再計算に置き換える原理と、平方根で効くメモリ削減、ZeROと重ねる勘所まで一本で押さえます。
- 1.活性化チェックポイント(再計算)は、逆伝播で使う中間活性をすべて保持せず一部だけ残し、必要になった時点で順伝播をやり直して復元する技法。活性メモリを計算時間と引き換えに削減する。
- 2.N層を等間隔にk個だけ保存すると活性メモリはおよそ O(N/k + k)、最適な区分数 k≈√N でメモリは O(√N) まで下がる。代償は順伝播1回ぶんの追加計算で、学習時間は約1.3〜1.5倍が目安。
- 3.削減されるのは活性メモリだけで、パラメータ・勾配・最適化状態は減らない。だから ZeRO の状態分割や混合精度と直交し、併用してこそ大規模学習のメモリが本当に回る。
なぜ活性を捨てたいのか:逆伝播はメモリを層数ぶん食う
順伝播は入力から出力へ一方向に流れるだけなので、素朴に考えれば各層の出力を次の層に渡した後は捨てられそうに見えます。ところが逆伝播はそれを許しません。連鎖律で勾配を計算するとき、多くの層は「順伝播時の入力(=前層の活性)」を必要とするからです(逆伝播の数理)。たとえば線形層 y = W x の重み勾配は dW = dy · x^T で、順伝播の入力 x がないと計算できません。ReLU も「どの要素が正だったか」というマスクを保持する必要があります。
このため標準的な自動微分は、順伝播中に生成した中間活性を計算グラフ上にすべて残し、逆伝播で消費します。結果として活性メモリは層数 N にほぼ比例して増え、深いネットワークや長い系列では、パラメータそのものより活性のほうがメモリを食う主因になります。バッチや系列長を上げると比例して膨らむのも活性です。ここで効くのが 活性化チェックポイント(activation checkpointing、勾配チェックポイント/再計算とも) です。
原理:保存する代わりに、必要になったら作り直す
発想は一言でいえば「メモリに残さず、逆伝播で要るときに順伝播をやり直して復元する」。順伝播ではごく少数の活性(チェックポイント)だけを残し、その間の中間活性はすべて解放します。逆伝播が該当区間に来たら、直近のチェックポイントを起点にその区間だけ順伝播を再実行して中間活性を一時的に作り直し、勾配を計算したらまた捨てます。時間をメモリと交換する技法で、勾配蓄積 と同じ「計算資源のトレードオフ」の系譜にあります。
# 通常: forward の全中間活性を保持したまま backward
y = block(x) # block 内部の活性がすべてメモリに残る
# チェックポイント: 入力 x だけ残し、内部活性は破棄。
# backward 時に block を再実行して内部活性を復元する
y = checkpoint(block, x) # forward は no_grad 相当で走らせ、活性を残さない
ポイントは、チェックポイント境界の活性さえ持っていれば、区間内はいつでも再現できるという決定性です。順伝播が決定的関数である限り、同じ入力から同じ中間活性が得られるので、保存を再計算で代替できます。逆にいうと、区間内に乱数や非決定的な副作用があると再計算で値がずれるため、後述の注意が要ります。
どれだけ減るか:等間隔チェックポイントと √N の法則
削減量は「どこにチェックポイントを置くか」で決まります。N 個の層を等間隔に区切って k 個のチェックポイントを置く古典的な戦略を考えます。このとき常時メモリに残るのは、(1) k 個のチェックポイント活性と、(2) 逆伝播中に1区間だけ再計算して展開する区間内の活性(約 N/k 個)です。合計はおおよそ次のようになります。
メモリ ≈ k + N/k (チェックポイント数 + 1区間の再計算分)
これを k で最小化すると k ≈ √N のとき最小 ≈ 2√N
つまり適切に区切れば活性メモリは O(N) から O(√N) へ落ちます。層数1万相当でも活性は百程度のオーダーに圧縮できる計算です。代償は再計算のための追加の順伝播で、区間ごとに1回ぶん余分に順伝播が走るため、学習全体では順伝播が実質2回・逆伝播が1回の構成になります。順伝播と逆伝播の計算量はおおむね1:2なので、追加コストは全体の約3割、学習時間は1.3〜1.5倍が実務的な目安です。
| 戦略 | 活性メモリ | 追加の再計算 | 使いどころ |
|---|---|---|---|
| 再計算なし(全保存) | O(N) | なし | メモリに余裕があり速度最優先 |
| 等間隔 k=√N 区分 | O(√N) | 順伝播ほぼ1回ぶん | 汎用。最も一般的な既定 |
| 全層チェックポイント | O(N)(各層境界のみ/内部は都度再計算) | 区間が細かく再計算増 | 極端にメモリが逼迫した場合 |
| 選択的(重い層のみ) | 層依存で中間 | 対象層だけ再計算 | Attention等ホットスポット狙い |
区分を細かくするほどメモリは減りますが、チェックポイント境界が増えて再計算区間が短くなり、境界の活性保存コストと再計算回数のバランスが崩れます。k ≈ √N は「保存」と「再計算」のコストが釣り合う点で、多くのフレームワークの既定挙動もこの近傍を狙います。
チェックポイントが削るのは中間活性のメモリであって、パラメータ本体・勾配バッファ・最適化状態(Adamのモーメント等)は一切減りません。混合精度Adamではパラメータ1個あたり16バイト級の状態が常駐しますが、これらは再計算の対象外です。したがって「モデルが大きすぎて重み・状態が載らない」問題には効かず、そこは後述の ZeRO の担当になります。活性が主因かパラメータ・状態が主因かを切り分けてから手を選ぶのが鉄則です。
実装上の落とし穴:乱数・BatchNorm・RNG状態
再計算は「同じ入力から同じ活性が再現できる」ことに依存します。ここに非決定的な演算が挟まると、順伝播時と再計算時で値がずれ、勾配が壊れます。代表がドロップアウトです。順伝播で引いたマスクと再計算で引くマスクが違えば別物になるため、フレームワークはRNG状態(乱数の種と位置)をチェックポイントに保存し、再計算時に復元して同じマスクを再現します。PyTorch の torch.utils.checkpoint が既定で RNG を保存するのはこのためで、自作の再計算ではここを取りこぼしやすい。
BatchNorm も注意が要ります。再計算で順伝播をもう一度走らせると、移動平均(running mean/var)の更新が二重に走る恐れがあり、統計がずれます。実装は再計算パスでの統計更新を抑制する必要があります。バッチ統計に依存しない LayerNorm・RMSNorm はこの種の副作用がなく、Transformer 系でチェックポイントが素直に機能する一因です(正規化層の相互作用の議論は勾配蓄積の記事も参照)。
- ドロップアウト等の乱数:RNG状態を保存・復元しないと、順伝播と再計算でマスクが変わり勾配が誤る。
checkpoint(..., preserve_rng_state=True)を維持する。 - BatchNormの統計更新:再計算パスで running 統計が二重更新されうる。統計更新の抑制、または統計非依存の正規化層を使う。
- インプレース演算・非決定的カーネル:再計算で値が一致しない演算は境界内に置かない。境界(保存する活性)は勾配が必要なテンソルであることを確認する。
大規模学習での位置づけ:ZeRO・混合精度と重ねる
超大規模学習のメモリは、活性・パラメータ・勾配・最適化状態の総和で決まります。活性化チェックポイントは活性の軸だけを叩くので、他の軸を叩く手法と直交して併用できます。ここが実務で最も効く点です。
- ZeRO(分散学習:データ並列・モデル並列・ZeRO) は最適化状態・勾配・パラメータを各GPUに
1/N分割し、状態側のメモリ重複を消します。活性には手を出しません。したがって「チェックポイントで活性を O(√N) に、ZeRO で状態を 1/N に」と役割分担でき、両方効かせて初めて巨大モデルが1枚のGPUの予算に収まります。ZeRO-Offload で状態をCPU/NVMeへ退避する構成とも競合しません。 - 混合精度学習 は活性・勾配をFP16/BF16にして各テンソルのバイト数を約半分にします。チェックポイントが活性の個数を減らすのに対し、混合精度は1個あたりのサイズを減らすので、これも直交します。両者を掛けると活性メモリは「半分 × O(√N)」で効きます。
- Attention のメモリ最適化 は、注意行列という特に巨大な中間活性を、そもそも物質化せずタイル計算で回避する専用手法です。汎用のチェックポイントと発想を共有しつつ、Attention 区間ではこちらの方が効率的な場合が多く、選択的チェックポイントの代わりに使われます。
- 目的:逆伝播に必要な中間活性を全保存せず、一部だけ残して再計算で復元し、活性メモリを計算時間と交換して削減する。
- 削減量:N層を等間隔
k区分で活性メモリ ≈k + N/k、最適k≈√Nで O(√N)。代償は順伝播ほぼ1回ぶんの追加計算(学習時間 約1.3〜1.5倍)。 - 限界:減るのは活性だけ。パラメータ・勾配・最適化状態は減らないので、そこは ZeRO が担当。
- 落とし穴:ドロップアウトは RNG状態を復元して再現、BatchNorm は統計の二重更新に注意。乱数・非決定演算を境界内に置かない。
まとめ:活性は「保存」ではなく「再現」できる資源
活性化チェックポイントは、順伝播が決定的である性質を突き、中間活性をメモリに残す代わりに逆伝播で作り直すことで、活性メモリを O(N) から O(√N) へ圧縮する技法でした。骨格は「境界の活性だけ残し、区間内は破棄、必要時に順伝播を再実行」。代償は順伝播ほぼ1回ぶんの追加計算で、時間でメモリを買う点は 勾配蓄積 と同じ精神です。ただし削るのは活性の軸だけ——パラメータ・勾配・最適化状態には効かないので、ZeRO の状態分割 や 混合精度 と重ねて初めて大規模学習のメモリ予算が回ります。再計算の落とし穴(RNG状態・BatchNorm統計)を押さえたうえで、メモリの主因がどの軸かを見極め、直交する手法を積み上げるのが最適化の要諦です。
AI/機械学習 Article
活性化チェックポイントとメモリ最適化を実務で読む
TL;DRは入口です。実際に選ぶ・使う段階では、何を解決するか、何と比較するか、導入後にどこで詰まるかまで見る必要があります。
解決すること
活性化チェックポイント
比較で見る軸
難易度: advanced / カテゴリ: AI/機械学習 / タグ数: 5
導入後に効く点
N層を等間隔にk個だけ保存すると活性メモリはおよそ O(N/k + k)、最適な区分数 k≈√N でメモリは O(√N) まで下がる。代償は順伝播1回ぶんの追加計算で、学習時間は約1.3〜1.5倍が目安。
先に潰すリスク
用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。
- 難易度
- advanced
- カテゴリ
- AI/機械学習
- タグ数
- 5
判断チェックリスト
- 自社の用途が「活性化チェックポイント / メモリ最適化」に近いか確認する。
- 強みである「活性化チェックポイント(再計算)は、逆伝播で使う中間活性をすべて保持せず一部だけ残し、必要になった時点で順伝播をやり直して復元する技法。活性メモリを計算時間と引き換えに削減する。」が本当に評価軸になるか確認する。
- 注意点の「用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。」を運用で吸収できるか確認する。
- 公開値や仕様値は、対象プラン・対象機種・対象リージョンまで確認する。
- 既存システム、ID、ネットワーク、監視、バックアップとの接続方法を先に洗い出す。
- 小さく試してから、本番移行、権限設計、障害時手順、コスト監視を決める。