分散学習:データ並列・モデル並列・ZeRO
巨大モデルが1枚のGPUに載らない理由と、データ・テンソル・パイプライン並列の分割軸、ZeROの状態分割を押さえれば、学習を多GPUへ正しくスケールさせる勘所が一本でつながる。
- 1.並列化は分割する軸で3種に整理できる。データ並列はバッチを割りモデルを複製、テンソル並列は1つの層の行列を割り、パイプライン並列は層をステージに割って配置する。
- 2.データ並列の同期はAllReduceで全GPUの勾配を平均する。通信量はモデルサイズに比例しバッチサイズに依存しないため、スケールするほど通信が律速になりやすい。
- 3.ZeROはデータ並列の弱点であるメモリ重複を解消する。最適化状態・勾配・パラメータを各GPUに分割保持し、通信量をほぼ増やさずに学習可能なモデルサイズを数十倍へ広げる。
なぜ1枚のGPUでは学習できないのか
LLM の学習パイプライン で見たように、現代のモデルは数十億〜数千億パラメータに達します。問題はパラメータ本体だけではありません。学習中のGPUメモリには、次の4つが同時に乗ります。
- パラメータ本体
- 勾配(パラメータと同数)
- 最適化状態(Adam なら1次・2次モーメントでパラメータの2倍)
- アクティベーション(順伝播の中間結果。逆伝播で必要)
混合精度(FP16/BF16 で計算、FP32 でマスタ重みと状態を保持)の典型では、パラメータ1個あたり 16バイトほどかかります(FP16のパラメータ2+勾配2、FP32のパラメータ4+Adamモーメント2つで8)。これだけで 1B パラメータ=約16GB、アクティベーションは別枠です。単一GPUに載らないので、複数GPUへ「何を」「どの軸で」割るかが分散学習の核心になります。
データ並列・テンソル並列・パイプライン並列は競合する技術ではなく、割る対象が違うだけです。データ並列は入力バッチを、テンソル並列は層内の行列を、パイプライン並列は層の並びを割ります。大規模学習では3つを掛け合わせる(3D並列)のが定石です。
データ並列:モデルを複製しバッチを割る
最も基本的なのがデータ並列(Data Parallelism)です。全GPUに同じモデルの完全なコピーを置き、ミニバッチを GPU 台数で分割して各GPUに配ります。各GPUは自分の小バッチで順伝播・逆伝播し、ローカルな勾配を得ます。
ここで重要なのが、各GPUの勾配は別々のデータから計算されるのでそのままでは食い違う点です。パラメータの一致を保つには、更新前に全GPUの勾配を平均して揃える必要があります。この集団通信が AllReduce です。
各GPU: g_i = ∂L_i/∂w (自分の小バッチでの勾配)
AllReduce: g = (g_0 + g_1 + ... + g_{N-1}) / N ← 全GPUで同一値を共有
各GPU: w ← w - lr * g (同じ更新なので重みは一致を保つ)
実装上の主役が Ring-AllReduce です。N台をリング状に並べ、勾配を N 個のチャンクに分け、Reduce-Scatter(各チャンクの合計を1台に集める)と All-Gather(結果を全台へ配る)の2段で実現します。各GPUが送受信するデータ量は 2×(N-1)/N × モデルサイズで、台数 N が増えてもほぼ一定です。だからデータ並列はスループットでよくスケールします。
1つ目は通信。AllReduce の通信量はモデルサイズに比例し、小バッチを使うほど計算に対する通信の比率が上がって律速します。2つ目はメモリ重複。全GPUがパラメータ・勾配・最適化状態を丸ごと持つため、台数を増やしても1台に載るモデルサイズの上限は変わりません。この重複こそ後述の ZeRO が叩く弱点です。
モデル並列:モデル自体を割る
モデルが1台に載らないなら、モデル自体を分割します。割り方が2つあります。
テンソル並列:1つの層を割る
テンソル並列(Tensor Parallelism)は、層内の行列演算を複数GPUに横断分割します。例えば全結合 Y = X · W の重み W を列方向に [W_1, W_2] と割れば、各GPUが Y_1 = X·W_1、Y_2 = X·W_2 を計算し、結果を連結すれば Y になります。Transformer の Attention やMLPブロックは、この分割が数学的にきれいに収まるよう設計できます。
代償は通信頻度です。1つの層を割るため、順伝播・逆伝播のたびに部分結果を集約する AllReduce が層ごとに発生します。通信が極めて高頻度かつレイテンシに敏感なので、テンソル並列は NVLink で密結合した同一ノード内に閉じるのが鉄則です。
パイプライン並列:層の並びを割る
パイプライン並列(Pipeline Parallelism)は、モデルの層をステージに分けて別々のGPUに置く縦の分割です。GPU0 が第1〜8層、GPU1 が第9〜16層、というように担当を分け、データはステージ間をバケツリレーで流れます。
素朴にやると、後段のGPUは前段の出力を待つ間アイドルになります。この空白がパイプラインバブルです。対策がマイクロバッチ化で、ミニバッチを細かく刻んで次々と投入し、各ステージを常時稼働させてバブルを薄めます(GPipe / 1F1B などのスケジューリング)。通信はステージ境界のアクティベーション授受のみなので、テンソル並列より通信は軽く、ノード間にまたがらせやすいのが特徴です。
| 並列方式 | 割る対象 | 通信の中身 | 通信頻度 | 適する範囲 |
|---|---|---|---|---|
| データ並列 | 入力バッチ | 勾配のAllReduce | イテレーション毎に1回 | ノード間でも可。基本軸 |
| テンソル並列 | 層内の行列 | 部分結果のAllReduce | 層ごと(非常に高頻度) | ノード内(NVLink必須級) |
| パイプライン並列 | 層の並び(ステージ) | ステージ境界のアクティベーション | マイクロバッチ毎 | ノード間に展開しやすい |
ZeRO:データ並列の「メモリ重複」を消す
ここで スケーリング則 が示す「巨大化の圧力」と、データ並列のメモリ重複が正面衝突します。ZeRO(Zero Redundancy Optimizer) は、データ並列の使いやすさ(バッチを割るだけ)を保ったまま、その重複を排除する手法です。
着想は単純です。データ並列では N 台が同一の状態を N 重に持っています。これを丸ごと持たせず、各GPUが全体の 1/N だけを保持し、必要なときだけ通信で取り寄せれば、メモリ重複が消えます。分割対象を段階的に増やしたのが ZeRO の3ステージです。
| ステージ | 分割するもの | 削減対象 | 1台あたりメモリ(概念) | 追加通信 |
|---|---|---|---|---|
| ZeRO-1 | 最適化状態のみ | Adamモーメント等 | 状態を 1/N に | ほぼ増えない |
| ZeRO-2 | +勾配 | 勾配も 1/N に | 状態+勾配を 1/N に | ほぼ増えない |
| ZeRO-3 | +パラメータ | パラメータも 1/N に | 全状態を 1/N に | All-Gatherが追加で必要 |
最初に効くのが ZeRO-1です。前述のとおり最適化状態は最大のメモリ食い(混合精度Adamで全16バイト中の12バイト)で、これを 1/N に割るだけで効果が大きい。ZeRO-2 は勾配も分割します。データ並列の AllReduce はもともと Reduce-Scatter と All-Gather の2段なので、各GPUが「自分の担当パラメータ分の勾配だけ」を Reduce-Scatter で受け取れば、勾配を分割保持でき通信量は据え置きです。
ZeRO-3 はパラメータ本体まで割ります。各GPUは自分の担当パラメータしか持たないので、ある層を計算する瞬間にその層の全パラメータを All-Gather で一時的に集め、使い終わったら捨てる。これによりパラメータ・勾配・状態のすべてが 1/N になり、台数を増やすほど巨大モデルが載るようになります。
テンソル並列は1つの行列を恒久的に割って各GPUが部分計算を担いますが、ZeRO-3 はパラメータを保管時だけ分散し、計算の瞬間には All-Gather で完全な層を再構成します。各GPUは依然として全層を順番に計算する点でデータ並列のままです。だからモデルコードを書き換えずに巨大モデルへ届く一方、All-Gather の通信が増えるトレードオフを負います(PyTorch の FSDP も同じ発想です)。
通信とメモリのトレードオフを設計する
分散学習の設計は、結局メモリ・通信・計算の三者のやりくりです。代表的なメモリ節約策も、多くは通信や計算との交換で成り立ちます。
| 技術 | 減らすもの | 代償 | ひとことで |
|---|---|---|---|
| 勾配チェックポイント | アクティベーションメモリ | 逆伝播で順伝播を再計算(計算増) | 中間結果を捨て、必要時に作り直す |
| ZeRO-Offload / Infinity | GPUメモリ | CPU/NVMeへの転送(通信・遅延増) | 状態をホスト側に退避させ容量を稼ぐ |
| 勾配累積 | 実効バッチのメモリ | 更新頻度が下がる | 小バッチを複数回ためてから1回更新 |
| 混合精度(BF16) | メモリと帯域 | 数値範囲・精度の注意 | FP16/BF16で計算しFP32でマスタ保持 |
実際の超大規模学習では、これらを階層的に組みます。ノード内はテンソル並列(高速NVLinkを活かす)、ノード間はパイプライン並列、その外側をデータ並列(+ZeRO)で包むという 3D 並列が典型です。割り当ては「通信が高頻度なものほど内側(近い)GPUへ」という原則で決まります。
- 並列は割る軸で分類:データ=バッチ、テンソル=層内行列、パイプライン=層の並び。
- データ並列の同期=AllReduce(Ring)。通信量はモデルサイズ比例・台数にほぼ非依存。弱点はメモリ重複。
- テンソル並列は高頻度通信ゆえノード内、パイプライン並列はバブルをマイクロバッチで薄めノード間向き。
- ZeRO は状態→勾配→パラメータの順に 1/N 分割。ZeRO-1/2 は通信ほぼ不変、ZeRO-3 は All-Gather が増える。
まとめ:分割軸とAllReduceで全体像をつなぐ
分散学習は、一見バラバラな技術の寄せ集めに見えても、**「何をどの軸で割り、その結果どんな通信が要るか」**という1つの問いに還元できます。
| 論点 | 実態 | そこから言えること |
|---|---|---|
| なぜ割るのか | パラメータ・勾配・状態・活性化が1台に載らない | メモリ要求を台数で分担するのが目的 |
| データ並列の本質 | モデル複製+勾配のAllReduce平均 | スケールするが状態を全台が重複保持する |
| モデル並列の本質 | 層内(テンソル)か層間(パイプライン)で割る | 通信頻度が配置(ノード内/間)を決める |
| ZeROの本質 | 重複する状態を 1/N に分散保持 | データ並列のままモデル上限を大きく押し上げる |
この見立てを持つと、「GPUを増やしてもモデルが大きくできない」「ノードをまたぐと急に遅い」「メモリは足りるのに通信で頭打ち」といった現象が、分割軸と通信パターンの言葉で説明できます。学習対象を小さく抑える設計は パラメータ効率ファインチューニング と、最適化状態がメモリを支配する理屈は 最適化アルゴリズムの系統 と合わせて読むと、巨大モデルを学習可能にする工夫の全体像が線でつながります。
AI/機械学習 Article
分散学習:データ並列・モデル並列・ZeROを実務で読む
TL;DRは入口です。実際に選ぶ・使う段階では、何を解決するか、何と比較するか、導入後にどこで詰まるかまで見る必要があります。
解決すること
分散学習
比較で見る軸
難易度: advanced / カテゴリ: AI/機械学習 / タグ数: 5
導入後に効く点
データ並列の同期はAllReduceで全GPUの勾配を平均する。通信量はモデルサイズに比例しバッチサイズに依存しないため、スケールするほど通信が律速になりやすい。
先に潰すリスク
用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。
- 難易度
- advanced
- カテゴリ
- AI/機械学習
- タグ数
- 5
判断チェックリスト
- 自社の用途が「分散学習 / データ並列」に近いか確認する。
- 強みである「並列化は分割する軸で3種に整理できる。データ並列はバッチを割りモデルを複製、テンソル並列は1つの層の行列を割り、パイプライン並列は層をステージに割って配置する。」が本当に評価軸になるか確認する。
- 注意点の「用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。」を運用で吸収できるか確認する。
- 公開値や仕様値は、対象プラン・対象機種・対象リージョンまで確認する。
- 既存システム、ID、ネットワーク、監視、バックアップとの接続方法を先に洗い出す。
- 小さく試してから、本番移行、権限設計、障害時手順、コスト監視を決める。