多トークン予測(Medusa・EAGLE)
別モデルを載せずに LLM を速くしたいなら多トークン予測。本体にヘッドを足す Medusa と、特徴量を先読みする EAGLE の原理とツリー検証を押さえれば、投機の主流手法を選び分けられる。
- 1.1回の前方計算で先の複数トークンを当てにいく手法。Medusa は本体の最終隠れ状態に複数の予測ヘッドを増設し t+2, t+3… を並列に候補化、EAGLE は出力直前の特徴量を自己回帰的に先読みして本体の LM ヘッドで token 化する。
- 2.EAGLE の勘所は、token 列より特徴量列のほうが滑らかで予測しやすいという観察。次に確定した token を1つ遅らせて入力に混ぜ、サンプリングの不確実性を吸収する。EAGLE は棄却サンプリングで無損失、Medusa は既定が typical acceptance で厳密無損失とは限らない。
- 3.いずれも複数候補を1本の列ではなくツリーに束ね、専用の注意マスクで1回の前方計算に詰めて一括検証する(ツリー検証)。EAGLE-2/Medusa-2 は木の形や学習を動的化し受理長を伸ばす。
なぜ「1ステップで複数トークン」を狙うのか
KV キャッシュと LLM 推論の最適化 で見たとおり、LLM の Decode はメモリ帯域律速です。1トークン出すだけでも全パラメータと KV を HBM から読み出す必要があり、演算ユニットは大きく余っています。ここで効くのが、投機的デコーディング と同じ発想、すなわち1回の前方計算で複数トークンを確定させて読み出しコストを割り勘することです。
標準の投機的デコーディングは、そのために別個の小さなドラフトモデルを用意しました。しかし別モデルは、語彙を揃える・メモリに2つ載せる・整合性を保つといった運用コストを伴います。これを避け、ターゲットモデル自身に先読みさせるのが本記事の主題、**多トークン予測(multi-token prediction)**です。代表格が Medusa と EAGLE で、いずれも「本体にわずかな部品を足すだけ」で複数トークンの候補を生み、それを一括検証します。
検証の数理(修正版棄却サンプリング、受理確率 min(1, p/q)、無損失性)は 投機的デコーディング で詳説済みです。本稿は候補をどう作るか、とりわけ Medusa の複数ヘッドと EAGLE の特徴量予測、そしてツリー検証に焦点を絞ります。
Medusa:本体に複数の予測ヘッドを足す
Transformer は通常、最終層の隠れ状態 h_t を1つの LM ヘッド(線形層+softmax)に通して「次の1トークン」を出します。Medusa は、この h_t の上に追加の軽量ヘッドを複数(Medusa ヘッド)並べます。
- ヘッド0(元の LM ヘッド): 次トークン
t+1を予測(従来どおり)。 - Medusa ヘッド1: 2手先
t+2を予測。 - Medusa ヘッド2: 3手先
t+3を予測。 - …と
t+kまで拡張。
各 Medusa ヘッドは「同じ h_t から、複数手先を直接当てる」小さな全結合ブロック(残差付き)で、本体パラメータは凍結したままヘッドだけを追加学習します。これが Medusa の第一の利点です。別モデルも、本体の再学習も基本要りません。
┌─→ LM ヘッド(既存) ─→ t+1 の候補分布
h_t ────────┼─→ Medusa ヘッド1 ─→ t+2 の候補分布
(最終隠れ状態) ├─→ Medusa ヘッド2 ─→ t+3 の候補分布
└─→ Medusa ヘッド3 ─→ t+4 の候補分布
すべて同一の前方計算1回で並列に得られる
ここで一つ、原理的な弱点があります。各ヘッドは同じ h_t だけを見て独立に先を当てるため、t+2 を予測する際に「t+1 が実際に何になったか」を条件にできません。先の位置ほど不確実性が積み上がり、単独の当たり精度は落ちます。Medusa はこれを、各ヘッドから上位数候補を取り出して組み合わせのツリーを作り(後述)、多めの経路を一括検証することで補います。当たりの1本を引ければよい、という設計です。
標準の投機的デコーディングは修正版棄却サンプリングで分布が完全一致します。一方 Medusa は既定の実装で**典型受理(typical acceptance)**を用いることが多く、これは「ターゲット分布から見て十分ありそうな候補」を温度としきい値で受理する緩い基準です。厳密な棄却サンプリングではないため、設定によっては出力分布がターゲット単独と厳密には一致しません。速度を優先し、多少の分布ずれを許容する割り切りです(貪欲デコード相当の設定なら実質一致します)。
Medusa には代表的に2段階の学習があります。Medusa-1 は本体を凍結しヘッドだけ学習する軽量版、Medusa-2 は本体も含めて共同微調整し受理長を伸ばす版です。手軽さの Medusa-1、性能の Medusa-2 という住み分けです。
EAGLE:token ではなく「特徴量」を先読みする
EAGLE(Extrapolation Algorithm for Greater Language-model Efficiency)は、Medusa とは別角度から多トークン予測を実現します。核心の観察はこうです。token の系列は離散でサンプリング由来のノイズを含み予測しにくいが、出力直前の「特徴量(feature、=最終層の隠れ状態)」の系列は連続的で滑らかであり、次を外挿しやすい。そこで EAGLE は、token 列ではなく特徴量列のほうを自己回帰的に先読みします。
仕組みは次のとおりです。EAGLE は小さな自己回帰ヘッド(1層 Transformer 程度)を1つ持ち、これに「これまでの特徴量」と「直前に確定した token の埋め込み」を入力して次の特徴量を予測します。得られた特徴量をターゲット本体と同じ LM ヘッドに通せば、token の候補分布が出ます。これを数手ぶん繰り返してドラフト列(実際にはツリー)を作り、最後に本体で一括検証します。
EAGLE の先読み(1ステップ):
入力 = [ 過去の特徴量 f_(≤t) , 直前に確定した token の埋め込み e_t ]
↓ 小さな自己回帰ヘッド(1層 Transformer 相当)
次の特徴量 f_(t+1) を予測
↓ ターゲットと同じ LM ヘッド(重み共有)
token t+1 の候補分布
→ サンプルした token を1つ遅らせて次ステップの入力へ混ぜる
EAGLE が Medusa と決定的に違う点が2つあります。第一に、Medusa の各ヘッドが h_t だけを見て独立に先を当てるのに対し、EAGLE は先読みヘッドを自己回帰で回すので「1つ先で何を出したか」を次の予測に反映できる。積み上がる不確実性を系列内で扱えるわけです。第二に、EAGLE は次に確定した token を1ステップ遅らせて(shift して)特徴量入力に足し込む設計をとります。特徴量だけを外挿すると「実際にどの token が選ばれたか(サンプリングの結果)」が抜け落ちますが、遅延 token を混ぜることでこのサンプリング由来の不確実性を吸収し、先読み精度を上げています。
token は softmax でサンプリングされた離散結果で、同じ意味でも別語になりうる「揺れ」を含みます。対して直前層の特徴量は連続ベクトルで、文脈の意味がなめらかに変化する信号です。滑らかな系列の外挿は、離散系列の外挿より一般に容易——これが「token を当てにいくより特徴量を当てにいくほうが受理率を稼げる」という EAGLE の主張の骨子です。加えて LM ヘッドを本体と共有するため、特徴量さえ合えば token 分布はターゲットと整合します。
EAGLE の検証は標準の(修正版)棄却サンプリングに載せられ、無損失に運用できます(Medusa の既定 typical acceptance との違い)。EAGLE-2 は、先読みヘッドの信頼度(スコア)に応じてツリーの形を文脈ごとに動的に変え、有望な枝を深く探ることで受理長を伸ばします。さらに後継では、単層特徴だけでなく複数層の情報を使うなどして精度を高める方向へ発展しています。
Medusa と EAGLE を並べて捉える
同じ「別モデルを持たない多トークン予測」でも、候補の作り方が根本的に異なります。
| 観点 | Medusa | EAGLE |
|---|---|---|
| 先読みの単位 | token(各ヘッドが t+k を直接予測) | feature(特徴量を自己回帰で外挿し LM ヘッドで token 化) |
| 位置間の依存 | 各ヘッドは h_t のみ参照し独立(先の条件を見ない) | 自己回帰+遅延 token 注入で先の結果を反映 |
| 追加する部品 | 複数の並列ヘッド(残差付き全結合) | 小さな自己回帰ヘッド1つ+本体 LM ヘッドを共有 |
| 無損失性 | 既定は typical acceptance(厳密一致とは限らない) | 棄却サンプリングで無損失運用が可能 |
| 主な発展 | Medusa-2 で本体も共同学習 | EAGLE-2 で動的ツリー、以降で多層情報活用 |
大づかみには、Medusa は「安く広く候補をばらまいてツリーで拾う」、**EAGLE は「滑らかな特徴量を丁寧に外挿して当てにいく」**という対比で覚えると要点を外しません。一般に同規模の追加コストなら EAGLE 系のほうが受理長(=実効速度)を稼ぎやすいと報告される一方、Medusa は構造が単純で実装・学習が軽い、という実務的トレードオフがあります。
ツリー検証:候補を「木」に束ねて1回で確かめる
多トークン予測に共通する武器がツリー検証(tree attention)です。素朴な投機は「候補を1本の列」として検証しますが、先頭で1つ外すとそれ以降が丸ごと無駄になります。そこで各位置で複数候補を持たせ、それらの組み合わせを木(複数経路)として1回の前方計算にまとめて検証します。
鍵は専用の注意マスクです。木の各ノード(候補 token)が「自分の祖先ノードだけ」を参照できるようにマスクを組むと、独立した複数の候補列を1つのバッチ化された系列として同時に流せます。1回の前方計算で全経路の p が並列に手に入り、その中から受理判定を最も長く通る1本を採用します。
1本の列(直列): A → B → C → D ← 先頭Aを外すと全滅
ツリー(複線化): ┌ B1 ┐
A ─┤ ├ … 複数の (B,C,…) 経路を
└ B2 ┘ 1回の前方計算でまとめて検証
専用の注意マスクで「各ノードは祖先のみ参照」を強制する
木を広く深くするほど1ラウンドで探索できる経路が増え、受理長の期待値が伸びます。ただし木を大きくすれば検証1回で処理する token 数(=計算量)も増えるため、演算がまだ余っている範囲で木を広げるのが得です。ここが「無限に広げれば速い」とはならない理由で、EAGLE-2 の動的ツリーは、この木の形を文脈の信頼度に応じて最適化して受理長と計算量のバランスを取る発想にあたります。
多トークン予測も万能ではありません。ヘッド/先読みヘッドの精度が低い、高エントロピーで次が割れる自由生成、あるいは木を広げすぎて検証コストが受理長の利得を上回る場合、追加計算のぶんだけかえって遅くなることがあります。デコーディング戦略 の温度や top-p を上げて分布を平坦化すると受理長が縮む傾向がある点も同じです。導入前に対象タスクで受理長(平均確定トークン数)を実測するのが鉄則です。
まとめ:本体を活かして Decode を割り勘にする
多トークン予測は、Decode がメモリ帯域律速で演算が余る構造を、別モデルなしで突く手法です。Medusa は本体にヘッドを足して複数手先を並列に候補化し、EAGLE は滑らかな特徴量を自己回帰で外挿して本体の LM ヘッドで token 化します。いずれも候補をツリーに束ねて1回の前方計算で一括検証することで、実効速度を稼ぎます。
| 論点 | 実態 | そこから言えること |
|---|---|---|
| 候補の作り方 | Medusa=複数ヘッドで token 直接予測/EAGLE=特徴量を外挿 | 別モデル運用を避けつつ複数トークンを先読みできる |
| EAGLE の勘所 | token より特徴量が滑らか+遅延 token で不確実性を吸収 | 自己回帰で先の条件を反映し受理長を伸ばせる |
| 検証の効率 | ツリー検証+祖先のみ参照する注意マスク | 複線の候補を1回の前方計算で確かめ全滅リスクを下げる |
| 無損失かどうか | EAGLE=棄却サンプリングで無損失/Medusa=既定 typical acceptance | 厳密一致が要るかで方式・設定を選ぶ |
実務では、受理長を実測してヘッド/先読みヘッドの学習と木の形を詰めるのが出発点です。検証の数理は 投機的デコーディング に、Decode がなぜ帯域律速かは KV キャッシュと LLM 推論の最適化 に、受理長を左右するサンプリング設定は デコーディング戦略 に接続しており、あわせて読むと推論高速化の打ち手が線でつながります。
AI/機械学習 Article
多トークン予測(Medusa・EAGLE)を実務で読む
TL;DRは入口です。実際に選ぶ・使う段階では、何を解決するか、何と比較するか、導入後にどこで詰まるかまで見る必要があります。
解決すること
LLM
比較で見る軸
難易度: advanced / カテゴリ: AI/機械学習 / タグ数: 5
導入後に効く点
EAGLE の勘所は、token 列より特徴量列のほうが滑らかで予測しやすいという観察。次に確定した token を1つ遅らせて入力に混ぜ、サンプリングの不確実性を吸収する。EAGLE は棄却サンプリングで無損失、Medusa は既定が typical acceptance で厳密無損失とは限らない。
先に潰すリスク
用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。
- 難易度
- advanced
- カテゴリ
- AI/機械学習
- タグ数
- 5
判断チェックリスト
- 自社の用途が「LLM / 推論最適化」に近いか確認する。
- 強みである「1回の前方計算で先の複数トークンを当てにいく手法。Medusa は本体の最終隠れ状態に複数の予測ヘッドを増設し t+2, t+3… を並列に候補化、EAGLE は出力直前の特徴量を自己回帰的に先読みして本体の LM ヘッドで token 化する。」が本当に評価軸になるか確認する。
- 注意点の「用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。」を運用で吸収できるか確認する。
- 公開値や仕様値は、対象プラン・対象機種・対象リージョンまで確認する。
- 既存システム、ID、ネットワーク、監視、バックアップとの接続方法を先に洗い出す。
- 小さく試してから、本番移行、権限設計、障害時手順、コスト監視を決める。