GQA/MQA:KVキャッシュを削減するアテンション変種
LLM 推論の帯域とメモリを食う KV キャッシュを、Key/Value ヘッドの共有で数分の一に圧縮する。MQA・GQA の原理と品質トレードオフ、MHA からの低コスト変換まで一気に押さえる。
- 1.MHA は Query ヘッドと同数の K/V ヘッドを持つが、推論の Decode はメモリ帯域律速で、毎ステップ KV キャッシュ全体を HBM から読む。MQA は全 Query ヘッドで K/V を1組に共有し、GQA はグループ単位で共有する中間案。
- 2.GQA を g グループにすると KV キャッシュと KV 読み出し帯域は MHA 比で約 g/h(h は Query ヘッド数)に縮む。Llama 2 70B は 64 Q ヘッドを 8 KV グループにして KV を約 1/8 にした。算術強度が上がり Decode が速くなる。
- 3.MHA を MQA/GQA へ後付けするには、各グループの K/V ヘッドを平均プールして初期化し、元の事前学習計算量の数%(uptraining)で微調整する。MQA は品質劣化が出やすく、GQA はほぼ無劣化で主流。
なぜ KV ヘッドを減らしたいのか
Multi-Head Attention の内部 で見たとおり、標準の MHA(Multi-Head Attention) は h 個のヘッドそれぞれに独立した Query・Key・Value を持たせます。学習時はこれで問題ありません。困るのは自己回帰の推論です。
KV キャッシュ で扱ったように、生成の Decode フェーズは メモリ帯域律速です。1トークンを出すたびに、これまで蓄えた KV キャッシュ全体を HBM(GPU の高帯域メモリ)から読み直す必要があり、ボトルネックは行列積の演算ではなくデータの読み出し量にあります。MHA では KV キャッシュ量がヘッド数に比例して大きいため、ここが直接の足かせになります。
「ヘッド数を減らさず、K/V ヘッドだけ減らすと何が起きるか」を原理から見ます。KV キャッシュそのものの仕組みは KV キャッシュ を前提とします。
算術強度(arithmetic intensity)で見る Decode
なぜ KV を読む量が支配的なのかは、算術強度=「メモリ1バイトあたり何回の演算をするか」で説明できます。Decode の1ステップで考えると、Query は1トークン分だけですが、K・V は系列長 n 件分を読みます。
Decode 1ステップの概算
読み出し量 ∝ モデル重み + KVキャッシュ(系列長 n に比例)
演算量 ∝ 上記重み・KVとの内積(同オーダー)
→ 1要素を読むたびに「数回」しか掛け算しない = 算術強度が低い
算術強度が低いと、GPU の演算器はデータ到着待ちで遊びます。だから Decode を速くする近道は「演算を速くする」ことではなく「読むデータを減らす」ことです。K・V のヘッド数を削れば、KV キャッシュも、毎ステップの KV 読み出しも、両方が同じ比率で縮みます。これが MQA/GQA の動機です。
MQA:K/V を1組に共有する
MQA(Multi-Query Attention) は、h 個の Query ヘッドを残したまま、Key と Value のヘッドを1組だけに減らします。全 Query ヘッドが同じ K・V を参照する形です。
MHA: Qヘッド h 個 ↔ Kヘッド h 個 / Vヘッド h 個 (1対1)
MQA: Qヘッド h 個 ↔ Kヘッド 1 個 / Vヘッド 1 個 (h 対1で共有)
KV キャッシュ量は KVヘッド数 に比例するので、h から1へ減らせば理論上 約 1/h に圧縮されます(例:h=64 なら約 1/64)。Decode の KV 読み出し帯域も同じ比率で下がり、生成スループットが大きく改善します。
代償は表現力です。各 Query ヘッドは「どこを見るか」を独立に決められますが、見る対象(K/V)が全ヘッドで共通になるため、ヘッドごとに異なる部分空間で関係を捉える MHA の利点が削がれます。実測では、特に複雑な推論や長文で品質低下と学習の不安定化が報告されており、ここが GQA を生んだ背景です。
GQA:グループ単位で共有する中間案
GQA(Grouped-Query Attention) は MHA と MQA の連続的な中間です。h 個の Query ヘッドを g 個のグループに分け、グループごとに K/V を1組共有します。
GQA(g グループ): Qヘッド h 個を g 分割
グループ1: Q[1..h/g] ↔ K/V 1組
グループ2: Q[...] ↔ K/V 1組
...
→ KVヘッド数 = g
g = hのとき GQA は MHA と一致(各グループ1ヘッド)。g = 1のとき GQA は MQA と一致(全体で1組)。
つまり GQA は1つのパラメータ g で MHA↔MQA を補間します。KV キャッシュは MHA 比で約 g/h に縮みます。少数のグループ(例 8)を残すだけで、各部分空間の多様性をある程度保てるため、MQA の品質劣化を避けつつ MQA に近い圧縮を得られます。
| 方式 | KVヘッド数 (Qヘッド h=64) | KVキャッシュ量(MHA比) | 品質への影響 |
|---|---|---|---|
| MHA | 64 | 基準(最大) | 最良だが KV が最も重い |
| GQA(g=8) | 8 | 約 1/8 | ほぼ劣化なし。現行の主流 |
| MQA(g=1) | 1 | 約 1/64 | 最小だがタスクにより品質低下・不安定化 |
グループ数は テンソル並列数(GPU 枚数)の倍数に取ると都合が良いことが多いです。各 GPU に少なくとも1つの KV ヘッドが乗るよう g を並列度に合わせると、KV ヘッドの複製や偏在を避けられます。Llama 2/3 系が g=8 を使うのはこの実装事情とも噛み合っています。
どれだけ削れるか:見積もりの実際
KV キャッシュ量は次の式で決まります(KV キャッシュ の式の KVヘッド数 がここで効く)。
KVキャッシュ量 ≒ 2 × 層数 × KVヘッド数 × head_dim × 系列長 × バッチ × 精度バイト数
↑ ここを h から g へ減らす
h=64 を g=8 にすれば、層数・系列長・バッチが同じでも KV は約 1/8。Decode は毎ステップこの KV を読むので、KV 読み出し帯域も約 1/8 になり、算術強度が上がって Decode が速くなります。同じ GPU メモリなら、空いた分でバッチを増やす/コンテキストを伸ばすこともでき、サービングのスループット上限が押し上がります。
GQA/MQA が減らすのは K/V ヘッドだけで、Query ヘッド数 h は MHA のまま維持します。つまりアテンションの「視点の数」は保ったまま「見る素材」を共有する設計です。Query まで減らすわけではない点を取り違えると、品質トレードオフの理解がずれます。
MHA から変換する:uptraining
魅力的なのは、既存の MHA モデルを GQA/MQA へ後付けで変換できることです。ゼロから事前学習し直す必要はありません。手順は概念的に2段階です。
1. 構築(checkpoint conversion):
各グループに属する複数の K/V ヘッドの投影行列を「平均プール」して
1組の K/V 投影に畳み込む(合計してヘッド数で割る)。
→ 平均は単一ヘッドへの初期化として、和や先頭ヘッド採用より安定。
2. 微調整(uptraining):
元の事前学習計算量の数%(論文では α≈5%)だけ追加学習し、
平均プールで生じたズレを回復させる。
ポイントはコストの小ささです。元の事前学習に比べわずかな計算(5%前後)で、MHA 並みの品質を GQA で取り戻せると報告されています。MQA も同様に変換できますが、回復後でも GQA より品質が落ちやすく、学習が不安定になりがちです。
| 論点 | MHA→MQA | MHA→GQA |
|---|---|---|
| KV 圧縮率 | 最大(約 1/h) | 中(約 g/h、g は調整可) |
| uptraining 後の品質 | 劣化が残りやすい | ほぼ MHA 同等 |
| 学習安定性 | 不安定になりやすい | 安定 |
| 主な用途 | 極端な帯域制約・軽量モデル | 現行 LLM の標準的選択 |
なぜ平均プールなのか
変換で 和 や 先頭ヘッド採用 ではなく 平均を選ぶ理由は、初期化のスケールにあります。複数ヘッドの投影行列を単純加算すると出力の大きさが グループ内ヘッド数 倍にずれ、softmax 前のロジット分布が壊れて学習が荒れます。平均なら各 K/V の出力スケールが元の1ヘッドと同程度に保たれ、変換直後でも元モデルにかなり近い振る舞いから微調整を始められます。これが少ない uptraining で済む理由です。
- MHA=Q と K/V が1対1。MQA=K/V を全体で1組共有。GQA=グループ単位で K/V を共有し
g=hで MHA、g=1で MQA に一致。 - 削減対象は KV キャッシュと KV 読み出し帯域で、MHA 比 約 g/h。Decode はメモリ帯域律速なのでこれがそのまま高速化につながる。
- uptraining=K/V ヘッドを平均プールして初期化し、事前学習の数%だけ微調整して MHA→GQA/MQA に変換する。GQA はほぼ無劣化、MQA は劣化が残りやすい。
全体像:帯域を起点に設計を選ぶ
GQA/MQA は「品質をできるだけ保ちつつ、Decode の帯域とメモリを削る」ための連続パラメータ g の選択問題です。g を小さくするほど KV は縮みますが、共有が進んで品質リスクが上がります。実務では g=8 前後が圧縮率と品質の良いバランス点として定着しています。
| 観点 | 実態 | そこから言えること |
|---|---|---|
| なぜ削るのか | Decode はメモリ帯域律速で KV 読み出しが支配的 | K/V ヘッドを減らせば帯域とメモリが同率で縮む |
| どう削るのか | GQA はグループ単位で K/V を共有(MHA↔MQA を補間) | g で圧縮率と品質を連続的に調整できる |
| 品質はどうか | GQA はほぼ無劣化、MQA は劣化が残りやすい | 現行の標準は GQA、MQA は極端な制約向け |
| 既存モデルは | 平均プール初期化+数%の uptraining で変換可 | ゼロから学習し直さず後付け移行できる |
この見立てを持つと、モデル選定で「GQA グループ数がいくつか」を見る意味、長文・大バッチで GQA モデルが有利な理由、そして FlashAttention のような計算段取りの最適化と GQA の読み出し量そのものの削減が別レイヤーで両立する理由が、すべて「帯域」という一つの軸で線につながります。
AI/機械学習 Article
GQA/MQA:KVキャッシュを削減するアテンション変種を実務で読む
TL;DRは入口です。実際に選ぶ・使う段階では、何を解決するか、何と比較するか、導入後にどこで詰まるかまで見る必要があります。
解決すること
LLM
比較で見る軸
難易度: advanced / カテゴリ: AI/機械学習 / タグ数: 5
導入後に効く点
GQA を g グループにすると KV キャッシュと KV 読み出し帯域は MHA 比で約 g/h(h は Query ヘッド数)に縮む。Llama 2 70B は 64 Q ヘッドを 8 KV グループにして KV を約 1/8 にした。算術強度が上がり Decode が速くなる。
先に潰すリスク
用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。
- 難易度
- advanced
- カテゴリ
- AI/機械学習
- タグ数
- 5
判断チェックリスト
- 自社の用途が「LLM / 推論最適化」に近いか確認する。
- 強みである「MHA は Query ヘッドと同数の K/V ヘッドを持つが、推論の Decode はメモリ帯域律速で、毎ステップ KV キャッシュ全体を HBM から読む。MQA は全 Query ヘッドで K/V を1組に共有し、GQA はグループ単位で共有する中間案。」が本当に評価軸になるか確認する。
- 注意点の「用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。」を運用で吸収できるか確認する。
- 公開値や仕様値は、対象プラン・対象機種・対象リージョンまで確認する。
- 既存システム、ID、ネットワーク、監視、バックアップとの接続方法を先に洗い出す。
- 小さく試してから、本番移行、権限設計、障害時手順、コスト監視を決める。