FlashAttention:メモリ効率的なアテンション計算
アテンションが長文で急に遅く重くなる正体は HBM への中間行列の書き戻し。タイル化とオンライン softmax でそれを断てば、計算結果を変えずに速度とメモリを同時に取り戻せる。
- 1.標準アテンションは n×n のスコア行列と softmax 結果を HBM に書き戻すため、計算量より HBM 入出力が律速する。FlashAttention は中間行列を HBM に一切書かず、ブロック単位で SRAM 上だけで完結させる IO-aware 設計。
- 2.鍵はオンライン softmax。行ごとの最大値と指数和を走らせながら更新し、ブロックを順に処理しても全列を一度に見たときと数学的に同一の結果を得る。近似ではなく厳密に等価。
- 3.アテンションのメモリ使用量を系列長 n に対し O(n²) から O(n) へ落とし、HBM アクセスを大幅削減して2〜4倍前後の高速化を得る。Prefill・Decode の双方で効き、長文ほど効果が大きい。
アテンションの真のボトルネックは計算ではない
Self-Attention の計算 で見たとおり、アテンションの中身は次の式です。
Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V
系列長を n、ヘッド次元を d とすると、QKᵀ は n × n の スコア行列 S になります。これを n が大きいほど巨大化します。素朴な実装は、この S を作り、softmax をかけて P(同じく n × n)にし、P·V を計算します。FLOPs は確かに O(n²·d) ですが、実機の遅さの主因は FLOPs ではありません。本当のボトルネックは、この n × n の中間行列を GPU の主記憶(HBM)に書き出し、また読み戻す メモリ往復にあります。
FlashAttention の論文が掲げる IO-aware とは、アルゴリズムを「演算回数」ではなく「メモリ階層間のデータ移動量」で設計するという立場です。同じ計算結果でも、どこに何を置きいつ動かすかで実速度が桁で変わる、という前提に立ちます。
GPU メモリ階層:SRAM と HBM の決定的な差
最適化の理由は GPU のメモリ階層にあります。容量と帯域が大きく違う2階層を区別するのが出発点です。
| 階層 | 容量の目安 | 帯域の目安 | 役割 |
|---|---|---|---|
| SRAM(オンチップ共有メモリ) | 1 SM あたり 数百 KB | おおむね数十 TB/s(桁違いに速い) | 計算ユニットの目の前。極小だが超高速 |
| HBM(GPU 主記憶) | 数十 GB(例 A100 で 40〜80GB) | おおむね 1〜3 TB/s | テンソルやモデル全体を置く。大容量だが SRAM より遅い |
ポイントは、SRAM が HBM より おおよそ一桁速い一方、容量は HBM の数万分の一 しかないことです。n × n のスコア行列は系列長が数千ともなれば SRAM に収まりません。だから標準実装は S と P を HBM へ書かざるを得ず、softmax 行ごとに「書く・読む」を繰り返します。この HBM 入出力が律速する 状態を memory-bound と呼びます。アテンションはまさにこれに当たり、演算ユニットは HBM 待ちで遊んでいます。
アテンションが「O(n²)」と言われるとき、多くの文脈で痛いのは計算量より メモリ量 です。n × n の中間行列を実体化(マテリアライズ)すると、系列長に対して二乗でメモリと HBM 往復が増えます。長文で OOM したり急に重くなる主因はここです。FlashAttention はこの中間行列を実体化しないことで、この二乗を消します。
オンライン softmax:全列を見ずに正規化する
中間行列を HBM に書かないには、K・V を縦に小ブロックへ割り、ブロックを順に SRAM へ載せて処理する必要があります。ところが素朴な softmax は分母に 全列の指数和 が要るため、本来は全 K を見終えるまで正規化を確定できません。これを解くのが オンライン softmax です。
数値安定化のため softmax は通常、行の最大値 m を引いてから指数を取ります。ブロックを順に見ていく過程で、それまでの最大値 m と指数和 l、そして部分的な出力 o を持ち回り、新ブロックごとに 補正しながら更新 します。
各クエリ行について、ブロックを順に処理:
m_new = max(m, ブロック内スコアの最大値)
α = exp(m - m_new) ← 旧最大値が更新された分の補正係数
l = l*α + Σ exp(s_j - m_new) ← 指数和を新基準にスケールし直して加算
o = o*α + Σ exp(s_j - m_new) * v_j ← 出力も同じ係数で補正して加算
m = m_new
最後に: 出力 = o / l ← 全ブロック走査後に一度だけ正規化
肝は補正係数 α = exp(m - m_new) です。後から大きいスコアが出て基準の最大値が更新されたとき、過去に積み上げた l と o を α 倍してスケールを揃え直します。この再スケールにより、ブロックを分割して逐次処理しても、全列を一度に見て softmax した結果と厳密に一致 します。FlashAttention は近似ではなく、出力が数学的に等価である点が決定的です。
線形アテンションや疎アテンションは計算を間引いて速くしますが、出力が変わるため精度トレードオフが生じます。FlashAttention は softmax を 計算する順序と置き場所だけ を変え、値は一切近似しません。だから既存モデルへ差し替えても出力が変わらず、学習・推論のどちらにも安全に入れられます。
タイル化:HBM 往復を消すループ構造
オンライン softmax を使えば、Q・K・V をブロック(タイル)に分け、二重ループで処理できます。外側を Q のブロック、内側を K・V のブロックで回します。
for each Q ブロック Qi (SRAM に載せる):
o, l, m を初期化
for each K,V ブロック Kj, Vj (SRAM に載せる):
Sij = Qi · Kjᵀ / √d_k ← SRAM 上で計算。HBM に書かない
オンライン softmax で o, l, m を更新(上の手順)
Oi = o / l
Oi だけを HBM に書き戻す ← 出力 n×d のみ。中間 n×n は書かない
このループでは、SRAM に載るのは小さなタイルだけで、巨大な S や P は 一度も HBM に実体化されません。HBM へ書き戻すのは最終出力 O(n × d、入力と同サイズ)だけです。結果としてアテンションの 追加メモリは O(n²) から O(n) へ下がり、HBM 入出力量も大幅に減ります。memory-bound だった処理が HBM 待ちから解放され、実測でおおむね 2〜4倍 の高速化が報告されています(モデル・系列長・ハードに依存)。
| 観点 | 標準アテンション | FlashAttention |
|---|---|---|
| 中間行列 n×n | HBM に実体化して書き戻す | 実体化しない。SRAM 上で消費 |
| 追加メモリ量 | O(n²) | O(n) |
| HBM 入出力 | 多い(律速要因) | 大幅削減 |
| 出力の値 | — | 標準と厳密に等価(近似なし) |
| 律速 | メモリ帯域(HBM) | 演算寄りに改善 |
逆伝播:保存しない代わりに再計算する
学習時の逆伝播は一見やっかいです。勾配計算には n × n の P が要りますが、FlashAttention はそれを保存していません。ここで取る戦略が 再計算(recomputation) です。順伝播では各行の正規化に使った m と l(各 O(n))だけを保存しておき、逆伝播で必要になったタイルの S・P を その場で作り直します。
一見すると無駄な再計算ですが、演算をやり直すコストより、巨大な P を HBM に置いて読み書きするコストの方が高い のが現代 GPU の実情です。だから「保存せず再計算」が正味で得になります。これは 自動微分 の文脈で言う勾配チェックポイントと同じ発想を、アテンション内部に特化して埋め込んだものと捉えられます。
- FlashAttention は 計算順序と置き場所だけ を変える IO-aware 設計。出力は標準と 厳密に等価(近似ではない)。
- 鍵は オンライン softmax:最大値
mと指数和lを補正係数α = exp(m - m_new)で再スケールしながら更新し、ブロック逐次でも全列 softmax と一致させる。 - タイル化 で
n × nを HBM に書かず SRAM で完結。追加メモリはO(n²)→O(n)、HBM 入出力を削減。逆伝播は 再計算 でPを持たずに済ませる。
KV キャッシュとの関係:別レイヤーの最適化
混同しやすいのが KV キャッシュ との違いです。両者は 別レイヤーの最適化 であり、競合せず併用します。
| 観点 | FlashAttention | KV キャッシュ |
|---|---|---|
| 何を最適化するか | 1回のアテンション計算の段取り | ステップ間での K/V の再利用 |
| 削るもの | 中間行列の HBM 往復 | 過去トークンの K/V 再計算 |
| 効くフェーズ | Prefill・Decode 双方 | 主に Decode |
| 性質 | 計算の置き場所を変える | 計算結果を使い回す |
FlashAttention は softmax の計算過程 を効率化し、KV キャッシュは K/V の再計算 を省く。前者は段取りの最適化、後者は結果の使い回しで、レイヤーが違うため両立します。実際の LLM サービングでは、PagedAttention・連続バッチングと並んで FlashAttention が標準的に組み込まれています。
まとめ:データ移動で考えるという発想
FlashAttention が示したのは、同じ数式・同じ結果でも、メモリ階層をどう使うかで実速度は桁で変わる という事実です。
| 論点 | 実態 | そこから言えること |
|---|---|---|
| なぜ遅いか | n×n 中間行列の HBM 往復が律速(memory-bound) | FLOPs を減らすより HBM 入出力を減らす方が効く |
| どう解くか | タイル化+オンライン softmax で SRAM 完結 | 中間行列を実体化せず、出力だけ書き戻す |
| なぜ安全か | softmax の順序を変えるだけで値は厳密等価 | 既存モデルに無改造で差し替え可能 |
| 学習はどうか | P を保存せず逆伝播で再計算 | 再計算コスト < HBM 往復コスト ゆえ得 |
この「IO-aware」――処理を演算量ではなくデータ移動量で捉える視点は、アテンションに限らず GPU 上の高速化全般に通じます。アテンション本体の数理は Self-Attention の計算 と Multi-Head Attention で、推論時のメモリ設計は KV キャッシュ で押さえると、長文 LLM の速度とメモリの全体像が線でつながります。
AI/機械学習 Article
FlashAttention:メモリ効率的なアテンション計算を実務で読む
TL;DRは入口です。実際に選ぶ・使う段階では、何を解決するか、何と比較するか、導入後にどこで詰まるかまで見る必要があります。
解決すること
アテンション
比較で見る軸
難易度: advanced / カテゴリ: AI/機械学習 / タグ数: 5
導入後に効く点
鍵はオンライン softmax。行ごとの最大値と指数和を走らせながら更新し、ブロックを順に処理しても全列を一度に見たときと数学的に同一の結果を得る。近似ではなく厳密に等価。
先に潰すリスク
用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。
- 難易度
- advanced
- カテゴリ
- AI/機械学習
- タグ数
- 5
判断チェックリスト
- 自社の用途が「アテンション / Transformer」に近いか確認する。
- 強みである「標準アテンションは n×n のスコア行列と softmax 結果を HBM に書き戻すため、計算量より HBM 入出力が律速する。FlashAttention は中間行列を HBM に一切書かず、ブロック単位で SRAM 上だけで完結させる IO-aware 設計。」が本当に評価軸になるか確認する。
- 注意点の「用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。」を運用で吸収できるか確認する。
- 公開値や仕様値は、対象プラン・対象機種・対象リージョンまで確認する。
- 既存システム、ID、ネットワーク、監視、バックアップとの接続方法を先に洗い出す。
- 小さく試してから、本番移行、権限設計、障害時手順、コスト監視を決める。