混合精度学習と数値安定性(FP16/BF16)
メモリ半減と高速化を、精度を落とさず両取りする鍵が混合精度。FP16 と BF16 の違いと損失スケーリングの原理を押さえれば、学習が NaN で落ちる原因まで見通せます。
- 1.混合精度はビット数を分担する設計。行列積など重い演算を16ビットで速く回し、勾配の累積や重み更新だけは32ビットの「マスター重み」で持って精度を守る。
- 2.FP16 と BF16 はビット配分が違う。FP16 は指数5・仮数10で精度は高いが表現範囲が狭く勾配アンダーフローを起こす。BF16 は指数8・仮数7で範囲が FP32 と同じため学習で好まれる。
- 3.損失スケーリングは FP16 専用の救済策。損失を定数倍して勾配を表現可能域へ押し上げ、更新前に同じ定数で割り戻す。BF16 は範囲が広くこの操作が原則不要。
なぜ混合精度なのか:速度・メモリと精度の綱引き
ディープラーニングの標準は長らく FP32(単精度、32ビット)でした。しかしモデルが巨大化すると、FP32 はメモリも演算コストも重すぎます。そこで登場するのが 混合精度学習(mixed precision training) です。発想は単純で、精度がさほど要らない演算は16ビットで速く回し、精度が死活的に効く一部だけ32ビットで守る という分担です。
16ビット化の利点は3つあります。(1) 重みや活性のメモリが半分になり、より大きなバッチやモデルが載る。(2) メモリ帯域の消費が半減し、帯域律速の処理が速くなる。(3) GPU の Tensor Core など16ビット行列積に特化したハードを使え、行列積のスループットが数倍になります。問題は、ビット数を削ると表現できる数の範囲と精度が狭まり、学習が不安定化したり NaN で落ちたりすることです。この副作用を原理から潰すのが本稿の主題です。
浮動小数点の構造:指数が「範囲」、仮数が「精度」を決める
まず土台です。浮動小数点数は 符号(1) + 指数部 + 仮数部 で構成され、おおまかに 値 = ±(1.仮数) × 2^(指数) と表されます。ここで決定的なのは役割分担です。
- 指数部のビット数 → 表現できる値の「範囲(ダイナミックレンジ)」 を決める。大きいほど、極端に大きい数や 0 に近い小さい数まで表せる。
- 仮数部のビット数 → 同じ桁での「精度(刻みの細かさ)」 を決める。大きいほど、近い2数を区別できる。
3つの形式をビット配分で並べると違いが一目で分かります。
| 形式 | 符号 | 指数部 | 仮数部 | 性質 |
|---|---|---|---|---|
| FP32(単精度) | 1 | 8 | 23 | 範囲広・精度高。基準だが重い |
| FP16(半精度) | 1 | 5 | 10 | 精度はそこそこ・範囲が狭い |
| BF16(bfloat16) | 1 | 8 | 7 | 範囲はFP32と同じ・精度は粗い |
ここが核心です。FP16 と BF16 は同じ16ビットでも、ビットの配り方が逆の思想になっています。FP16 は仮数に10ビット割いて精度を優先した代わり、指数が5ビットしかなく範囲が極端に狭い。BF16 は指数を8ビット確保してFP32 と同じ範囲を持つ代わり、仮数が7ビットしかなく精度が粗い。この一点が、後述する学習挙動の差をすべて説明します。
FP16 が表せる正の最小の正規化数はおよそ 6e-5、最大はおよそ 65504 です。これより 0 に近い値は アンダーフローして 0 に丸められ、これより大きい値は オーバーフローして無限大 になります。一方 BF16 は指数部が FP32 と同じ8ビットなので、最大はおよそ 3e38、最小の正規化数はおよそ 1e-38 と、FP32 とほぼ同じ範囲をカバーします。精度(刻み)は粗いまま、です。
混合精度の基本骨格:マスター重みと FP32 累積
「16ビットで全部やる」と学習は壊れます。理由は2つあり、いずれも仮数の少なさ(精度不足)に起因します。
第一に 重み更新の消失 です。更新式は θ ← θ - η·g ですが、学習後半は η·g が重み θ に対して桁違いに小さくなります。16ビットの粗い刻みでは 大きな数 + 極小の数 = 大きな数のまま となり、更新が丸めで消えて学習が止まります。第二に、総和や内積で多数の項を16ビットで足し込むと、丸め誤差が累積して精度が崩れます。
そこで標準的な混合精度は次の役割分担を取ります。
# 計算(forward/backward)は16ビットで高速に
fp16_w = cast_to_16bit(master_w) # 重みの16ビット複製
activ = forward(fp16_w, x) # 行列積などはTensor Coreで高速
grad16 = backward(activ) # 勾配も16ビットで算出
# 更新だけはFP32で精度を守る
grad32 = cast_to_fp32(grad16)
master_w = master_w - lr * grad32 # FP32のマスター重みを更新
ポイントは FP32 の「マスター重み(master weights)」を常に正本として保持し、毎ステップそこから16ビット複製を作って計算に使うことです。更新(足し込み)は精度の高い FP32 上で行うため、微小な更新も消えません。加えて、ソフトマックスやバッチ統計、損失の総和といった総和・指数を含む数値的に敏感な演算は FP32 で実行するのが定石です(正規化層 の統計計算も同様の配慮が要ります)。
勾配アンダーフロー:FP16 が学習で詰まる本当の理由
FP16 固有の最大の落とし穴が 勾配アンダーフロー です。逆伝播で計算される勾配は、深いネットワークでは非常に小さな値になりがちです(この縮小の連鎖は 勾配消失と残差接続 と地続きの現象です)。勾配の多くが FP16 の正規化最小値 6e-5 を下回ると、それらは一律 0 に丸められます。
勾配が 0 になったパラメータは、その層が学習に寄与しているにもかかわらず更新を受けません。これは 勾配降下法 の前提を静かに破壊します。重要なのは、問題は値が小さいこと自体ではなく、FP16 の表現範囲の「床」に当たって情報が消えることだという点です。BF16 なら同じ小ささでも 1e-38 まで表せるので、まず床に当たりません。
損失スケーリング:勾配を「表現可能域」へ押し上げる
FP16 でこのアンダーフローを回避する標準手法が 損失スケーリング(loss scaling) です。原理は微分の線形性そのものです。損失 L を定数 S 倍してから逆伝播すると、連鎖律によりすべての勾配が一様に S 倍されます。
scaled_loss = L * S # 損失をS倍(例 S = 1024 や 65536)
grad16 = backward(scaled_loss) # 勾配がすべてS倍され、FP16の床から脱出
grad32 = cast_to_fp32(grad16)
grad32 = grad32 / S # 更新前にSで割り戻し、本来の勾配に復元
master_w = master_w - lr * grad32
S を掛けることで、本来 6e-5 未満で 0 に潰れていた小さな勾配が FP16 の表現域に持ち上がり、丸めを免れます。逆伝播が終わったら重み更新の前に必ず S で割り戻すので、最終的な更新量は数学的に変わりません。やっていることは「小さな数を一旦大きくして、消えないうちに計算し、最後に元へ戻す」だけです。
S が小さすぎると勾配が床に当たってアンダーフローが残り、大きすぎると今度は勾配が FP16 の天井 65504 を超えてオーバーフロー(Inf/NaN) します。固定値の調整は難しいため、実務では 動的損失スケーリング(dynamic loss scaling) が使われます。大きめの S から始め、勾配に Inf/NaN が出たそのステップは更新を破棄して S を半減、一定回数 Inf/NaN が出なければ S を倍増、という自動調整で範囲の中央に張り付かせます。
なぜ BF16 が学習で好まれるのか
ここまでを踏まえると、近年の大規模学習で BF16 が標準 になった理由は明快です。BF16 は指数部が8ビットで FP32 と同じ範囲を持つため、勾配アンダーフローも活性のオーバーフローも原理的にほぼ起きません。結果として、損失スケーリングという厄介な調整機構が原則不要になります。動的スケーリングの実装やチューニングから解放される実務的価値は大きいものです。
代償は精度です。BF16 の仮数は7ビットしかなく、FP16(10ビット)より刻みが粗い。しかし学習では、勾配にもともとミニバッチ由来のノイズが乗っており、仮数の粗さは範囲の狭さほど致命的になりません。「精度が少し粗くても更新の方向が概ね合っていればよい」という学習の性質に、BF16 のビット配分がよく噛み合うわけです。だからこそ Transformer など大規模モデルの事前学習(スケーリング則 が支配する領域)では BF16 が事実上の既定になりました。
- FP16 と BF16 の本質差:同じ16ビットでビット配分が逆。FP16=指数5/仮数10(高精度・狭範囲)、BF16=指数8/仮数7(FP32 同等の範囲・低精度)。
- 損失スケーリングの目的と仕組み:FP16 の勾配アンダーフロー回避。損失を
S倍して勾配を表現域へ持ち上げ、更新前にSで割り戻す。BF16 では原則不要。 - マスター重みがなぜ FP32 か:学習後半の微小な更新が16ビットの丸めで消えるのを防ぐため、更新の足し込みは FP32 上で行う。
- BF16 が学習で好まれる理由:範囲が FP32 と同じでアンダーフロー/オーバーフローが起きにくく、損失スケーリングが要らないから。
まとめ:ビットの配り方を理解すれば安定化は怖くない
混合精度学習は「16ビットの速度・省メモリ」と「32ビットの精度」を演算ごとに分担する設計でした。骨格は、計算は16ビット・更新は FP32 のマスター重み、敏感な総和は FP32、です。そして FP16 と BF16 の挙動差は、指数部(範囲)と仮数部(精度)のビット配分の違い 一点に還元されます。FP16 は精度を取って範囲を失ったがゆえに勾配アンダーフローを招き、その救済策が損失スケーリング。BF16 は範囲を FP32 と揃えたがゆえにスケーリング不要で、粗い精度は学習のノイズ耐性に吸収される——この対比を押さえれば、学習が NaN で落ちたときに「範囲の問題か、精度の問題か」を切り分けられます。最適化の側からの安定化は 最適化アルゴリズムの系統 と合わせて読むと立体的に理解できます。
AI/機械学習 Article
混合精度学習と数値安定性(FP16/BF16)を実務で読む
TL;DRは入口です。実際に選ぶ・使う段階では、何を解決するか、何と比較するか、導入後にどこで詰まるかまで見る必要があります。
解決すること
混合精度
比較で見る軸
難易度: advanced / カテゴリ: AI/機械学習 / タグ数: 5
導入後に効く点
FP16 と BF16 はビット配分が違う。FP16 は指数5・仮数10で精度は高いが表現範囲が狭く勾配アンダーフローを起こす。BF16 は指数8・仮数7で範囲が FP32 と同じため学習で好まれる。
先に潰すリスク
用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。
- 難易度
- advanced
- カテゴリ
- AI/機械学習
- タグ数
- 5
判断チェックリスト
- 自社の用途が「混合精度 / FP16」に近いか確認する。
- 強みである「混合精度はビット数を分担する設計。行列積など重い演算を16ビットで速く回し、勾配の累積や重み更新だけは32ビットの「マスター重み」で持って精度を守る。」が本当に評価軸になるか確認する。
- 注意点の「用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。」を運用で吸収できるか確認する。
- 公開値や仕様値は、対象プラン・対象機種・対象リージョンまで確認する。
- 既存システム、ID、ネットワーク、監視、バックアップとの接続方法を先に洗い出す。
- 小さく試してから、本番移行、権限設計、障害時手順、コスト監視を決める。