リング注意と文脈並列
1枚のGPUに載らない超長文脈を、系列を割って複数デバイスへ分散し学習・推論する。リング状のKV交換と通信の重ね合わせで、近似なしにメモリを系列長に対し一定へ抑える勘所を掴めます。
- 1.文脈並列は系列(トークン軸)を割ってデバイスへ分散する第4の並列軸。各デバイスは自分の担当クエリだけを保持し全体のKVを持たないため、アクティベーションが系列長で頭打ちしない。
- 2.リング注意はKVブロックをリング状に隣へ順送りしながらブロック単位でアテンションを計算し、オンラインsoftmaxで逐次集約する。転送と計算を重ねるため、通信を計算の陰に隠せば追加コストがほぼ見えない。
- 3.結果はフルアテンションと厳密に等価で近似なし。1デバイスあたりのメモリは全系列長ではなくブロック長に依存するため、デバイス数を増やすほど扱える文脈長が線形に伸びる。
なぜ長文脈は1枚のGPUに載らないのか
Self-Attention の計算 で見たとおり、アテンションは各クエリが全キーを参照します。系列長を n、ヘッド次元を d とすると、素朴な実装はスコア行列を n × n で作るため、メモリが系列長に対し二乗で膨らみます。FlashAttention はこの n × n の中間行列を実体化せず、追加メモリを O(n) まで落としました。しかし O(n) でも、n が数十万〜数百万トークンに達すると、入力・出力のアクティベーションや各層の中間結果が積み上がり、単一デバイスのメモリを超えます。
問題の本質は、分散学習 で扱ったデータ並列・テンソル並列・パイプライン並列の3軸が、いずれも 系列(トークン)軸を割らない 点にあります。バッチを割っても層内の行列を割っても、1本のシーケンス全体は依然として各デバイスに丸ごと乗ります。そこで登場するのが、系列そのものを分割する第4の軸です。
文脈並列(context parallelism、シーケンス並列とも呼ぶ)は、既存の並列と競合せず 割る対象が違う だけです。データ並列がバッチを、テンソル並列が層内の行列を割るのに対し、文脈並列は 1本のシーケンスをトークン方向に分割 し、各デバイスに連続した区間を担当させます。3D並列にこの軸を加えた形が、超長文脈学習の定石です。
系列を割ったときの障壁:アテンションだけが「全体」を要る
系列を P 台のデバイスへ割り、各デバイス i がトークン区間 i 分の Q_i・K_i・V_i を持つとします。ここで層ごとの処理を見分けると、大半の演算はローカルで完結 します。埋め込み・全結合・LayerNorm・活性化 はトークンごとに独立なので、各デバイスは自分の区間だけで計算できます。
厄介なのはアテンションだけです。クエリ Q_i は自分の区間のキーだけでなく 全区間のキー K(= K_0 から K_{P-1} まで) を参照しなければなりません。つまり各デバイスは、自分が持っていない他デバイスのKVを何らかの形で手に入れる必要があります。
素朴な解は「全デバイスのKVを一度に集める」All-Gather です。しかしこれは全KVを各デバイスに実体化するため、メモリが再び全系列長 n に比例 してしまい、系列を割った意味が消えます。長文脈では K・V そのものが巨大なので、これは致命的です。求められるのは、全KVを同時に持たずに、全KVを順に見る 仕組みです。
リング注意:KVを隣へ回しながら少しずつ計算する
リング注意(Ring Attention)の着想は、分散学習 の Ring-AllReduce と同じ「リング」です。P 台のデバイスを論理的な輪に並べ、各デバイスは 自分の Q_i を固定したまま、KVブロックだけをリング上で順送り します。
各ステップで、デバイス i は「いま手元にあるKVブロック」と自分の Q_i でアテンションを計算し、結果を積算します。同時に、そのKVブロックを次の隣デバイスへ送り、前の隣デバイスから新しいKVブロックを受け取ります。これを P ステップ繰り返すと、各デバイスの Q_i は 全デバイス由来のKVを一巡 し、フルアテンションが完成します。
デバイス i(担当クエリ Q_i、初期KVは自分の K_i, V_i):
o, l, m を初期化 ← オンライン softmax の統計量
現在KV ← (K_i, V_i)
for step in 0 .. P-1:
# 計算と通信を同時に走らせる
非同期送信: 現在KV を 次の隣 (i+1) へ
非同期受信: 次に使うKV を 前の隣 (i-1) から
# 受信を待つ間にローカル計算
ブロック単位で S = Q_i · 現在KV.Kᵀ / √d
o, l, m を オンライン softmax で更新(下記)
通信の完了を待つ
現在KV ← 受信したKV
出力 O_i = o / l ← P 巡した時点で全KVを反映済み
重要なのは、各デバイスが 同時に保持するKVは1ブロック分だけ という点です。全KVを一度に集める All-Gather と違い、リングを回すことで メモリはブロック長に依存し、全系列長 n には依存しません。
鍵は「全KVを同時に持つ」ことと「全KVを順に見る」ことの分離です。All-Gather は前者で、全KVをマテリアライズするためメモリが n に比例します。リング注意は後者で、一度に手元にあるのは1ブロックだけ。処理し終えたブロックは次へ送って手放すので、時間軸に沿って全体を舐めても瞬間のメモリは一定に保てます。デバイス数 P を増やすほど1ブロックは小さくなり、扱える総文脈長が線形に伸びます。
オンライン softmax:ブロックを分割しても厳密に一致する
「一度に1ブロックずつ」処理して、なぜフルアテンションと同じ結果になるのか。ここを支えるのが、FlashAttention と同じ オンライン softmax です。softmax の分母は全キーにわたる指数和が要るため、本来は全KVを見終えるまで正規化を確定できません。しかし行ごとの最大値 m と指数和 l、部分出力 o を持ち回り、ブロックごとに補正すれば逐次集約できます。
新しいKVブロックのスコア s_j に対して:
m_new = max(m, ブロック内スコアの最大値)
α = exp(m - m_new) ← 旧統計量の再スケール係数
l = l * α + Σ exp(s_j - m_new)
o = o * α + Σ exp(s_j - m_new) * v_j
m = m_new
最後に: O_i = o / l ← 全ブロック走査後に一度だけ正規化
補正係数 α = exp(m - m_new) が肝で、後から大きいスコアが出て最大値が更新されたとき、過去に積んだ l と o をスケールし直します。この再スケールにより、KVを何ブロックに割って、どんな順に見ても、全キーを一度に見た softmax と厳密に一致 します。リング注意は近似アテンションではなく、線形アテンション のような精度トレードオフを持ちません。出力が数学的に等価なので、既存モデルの学習・推論にそのまま差し込めます。
通信と計算の重ね合わせ:リングが速い理由
リング注意の実用性は 通信と計算のオーバーラップ に懸かっています。各ステップでデバイスがやることは2つ――(1) 手元のKVで Q_i のアテンションを計算する、(2) KVブロックを隣へ送り次を受け取る。この2つを 同時に走らせる のが要点です。
| 方式 | 同時に持つKV | メモリ | 通信の隠蔽 | 結果 |
|---|---|---|---|---|
| All-Gather 方式 | 全デバイス分のKV | 系列長 n に比例(頭打ち) | 集約待ちが露出しやすい | 厳密 |
| リング注意 | 1ブロック分のみ | ブロック長に依存(n 非依存) | 計算の陰に隠せる | 厳密(等価) |
1ブロックのアテンション計算に要する時間を T_compute、1ブロックのKV転送に要する時間を T_comm とすると、両者を重ねたステップ時間は max(T_compute, T_comm) に近づきます。もし T_comm が T_compute 以下 なら、通信は計算の裏に完全に隠れ、リングを回す追加コストは実質ゼロになります。ブロック長を大きく取るほど T_compute が伸びて通信を吸収しやすくなる一方、瞬間メモリは増えるため、ここが設計上のつまみです。
オーバーラップは万能ではありません。ブロックを細かく割りすぎると T_compute が痩せて T_comm が露出し、リングのステップが通信律速になります。また因果マスク(各トークンが未来を見ない自己回帰)では、リング上の位置関係により 一部のブロックはマスクで全て捨てられ計算が要らない ため、素朴に回すと負荷が偏ります。実装では前半・後半のブロックを組み替えて各デバイスの計算量を均す(負荷分散する)工夫が入ります。
シーケンス並列との違いと、実務での位置づけ
用語が紛らわしいのが「シーケンス並列」です。文脈により2つの意味があります。狭義のシーケンス並列は、テンソル並列と組で使い、LayerNorm や Dropout などテンソル並列が割り残す部分をトークン方向に割ってアクティベーションメモリを節約する技法を指します。一方、リング注意が実現するのは アテンションまで含めてシーケンス全体を割る 文脈並列で、こちらの方が守備範囲が広い概念です。
| 観点 | 文脈並列(リング注意) | テンソル並列 | データ並列 |
|---|---|---|---|
| 割る軸 | 系列(トークン) | 層内の行列 | バッチ |
| アテンションの扱い | KVをリングで巡回し分担 | ヘッド/次元を分割 | 各デバイスが全系列を計算 |
| 主な削減対象 | 系列長由来のアクティベーション | 層内の重みと計算 | なし(複製) |
| 通信の中身 | KVブロックの順送り | 部分結果の AllReduce | 勾配の AllReduce |
| 適する範囲 | 超長文脈(学習・推論) | ノード内(NVLink) | 基本軸・ノード間可 |
実際の超長文脈システムでは、これらを掛け合わせます。ノード内はテンソル並列、その上に文脈並列でシーケンスを割り、さらにデータ並列で束ねる、という多軸構成です。文脈並列は KV キャッシュ と組めば、推論時の超長プロンプト処理(prefill)でも効きます。KVキャッシュが1本のシーケンスの過去KVを保持するのに対し、文脈並列はその1本自体を複数デバイスへ割るので、両者はレイヤーが違い併用できます。
- 文脈並列=系列(トークン軸)を割る第4の並列。層内演算はローカルで完結し、全体を要るのはアテンションだけ。
- リング注意はKVブロックをリング状に順送り し、
Q_iを固定してPステップで全KVを一巡。同時保持KVは1ブロックのみで、メモリは系列長nに非依存。 - オンライン softmax で逐次集約するため、フルアテンションと 厳密に等価(近似なし)。
- 実用性は 通信と計算のオーバーラップ。
T_commがT_compute以下なら通信は計算の陰に隠れ、追加コストがほぼ消える。因果マスク時は負荷分散が要る。
まとめ:系列を割るという発想
リング注意と文脈並列が示すのは、「どの軸で割り、その結果どんな通信が要るか」 という分散学習の一般問題に、系列軸という新しい答えを加えたことです。
| 論点 | 実態 | そこから言えること |
|---|---|---|
| なぜ割るのか | 長文脈のアクティベーションが1台に載らない | バッチ・層内では足りず系列軸の分割が要る |
| 障壁は何か | アテンションだけが全KVを参照する | 全KVを同時に持たず順に見る仕組みが要る |
| どう解くか | KVをリングで巡回+オンライン softmax | 瞬間メモリを一定に保ちつつ全体を集約できる |
| なぜ速いか | KV転送を計算の裏で重ねる | 通信 ≤ 計算なら追加コストがほぼ消える |
この見立てを持つと、「文脈長を伸ばすとOOMする」「デバイスを足しても長文が扱えない」といった壁が、分割軸と通信パターンの言葉で 説明できます。アテンション本体の効率化は FlashAttention と KV キャッシュ で、並列の全体像は 分散学習 で押さえると、超長文脈を学習・推論可能にする工夫が一本の線でつながります。
AI/機械学習 Article
リング注意と文脈並列を実務で読む
TL;DRは入口です。実際に選ぶ・使う段階では、何を解決するか、何と比較するか、導入後にどこで詰まるかまで見る必要があります。
解決すること
アテンション
比較で見る軸
難易度: advanced / カテゴリ: AI/機械学習 / タグ数: 5
導入後に効く点
リング注意はKVブロックをリング状に隣へ順送りしながらブロック単位でアテンションを計算し、オンラインsoftmaxで逐次集約する。転送と計算を重ねるため、通信を計算の陰に隠せば追加コストがほぼ見えない。
先に潰すリスク
用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。
- 難易度
- advanced
- カテゴリ
- AI/機械学習
- タグ数
- 5
判断チェックリスト
- 自社の用途が「アテンション / 文脈並列」に近いか確認する。
- 強みである「文脈並列は系列(トークン軸)を割ってデバイスへ分散する第4の並列軸。各デバイスは自分の担当クエリだけを保持し全体のKVを持たないため、アクティベーションが系列長で頭打ちしない。」が本当に評価軸になるか確認する。
- 注意点の「用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。」を運用で吸収できるか確認する。
- 公開値や仕様値は、対象プラン・対象機種・対象リージョンまで確認する。
- 既存システム、ID、ネットワーク、監視、バックアップとの接続方法を先に洗い出す。
- 小さく試してから、本番移行、権限設計、障害時手順、コスト監視を決める。