自動微分(Autograd)の仕組み
あなたが`loss.backward()`と書くだけで全パラメータの勾配が返る理由がここに。連鎖律を機械的に積み上げる自動微分の内部を、前進・逆伝播モードと計算グラフから正確に解く。
- 1.自動微分は連鎖律を機械的に適用して厳密な勾配を計算する技術。数値微分の丸め誤差も、数式微分の式爆発も避けられます。
- 2.逆伝播モードは出力1つに対し全入力の勾配を1回の逆走査で得るため、損失1つ・パラメータ多数の深層学習に最適。前進モードは逆。
- 3.PyTorch は実行と同時にテープへ演算を記録する動的グラフ、JAX は関数をトレースして変換する静的な方式で、いずれも連鎖律を自動化します。
そもそも「微分を自動で」とはどういうことか
深層学習の学習は、損失をパラメータで微分した勾配に沿ってパラメータを更新する作業の繰り返しです(勾配降下法)。問題は、現代のモデルが数億〜数千億個のパラメータを持ち、しかも合成関数が何百層も積み重なる点にあります。これを人手で微分するのは不可能です。
プログラムの微分を得る方法は、大きく「手計算・記号微分・数値微分・自動微分」の4つに整理できます。よく混同されますが、自動微分(automatic differentiation, AD)は他の3つのどれとも異なる独自のアプローチです。
| 手法 | やり方 | 弱点 |
|---|---|---|
| 手計算(数式微分) | 人間が導関数の式を導く | 層が深いと式が爆発、ミスも出る |
| 記号微分 | 数式処理系が導関数の式を生成 | 式が指数的に膨張(expression swell) |
| 数値微分 | (f(x+h)−f(x))/h で近似 | 丸め誤差と打ち切り誤差、入力次元の数だけ評価が必要 |
| 自動微分(AD) | 各基本演算の微分を連鎖律で合成 | グラフ構築のメモリ・実装の手間 |
自動微分の核心は、プログラムを「加算・乗算・exp・sin などの基本演算の連なり」とみなし、各演算の局所的な微分を連鎖律で機械的に掛け合わせる点にあります。数式としての導関数を「書き下す」のではなく、特定の入力点での勾配の値を、元の計算と同じ計算量オーダーで求めます。だから記号微分のような式爆発も、数値微分のような近似誤差も起きません。得られる勾配は丸め誤差の範囲で厳密です。
合成関数 y = f(g(x)) の微分は dy/dx = f'(g(x)) · g'(x)。自動微分は、プログラム全体を無数の基本演算の合成とみなし、この掛け算をひたすら積み上げているだけです。各演算の局所微分(ヤコビアン)さえ分かれば、全体の勾配は機械的に得られます。
計算グラフ:プログラムを微分可能な構造に変える
自動微分の出発点は 計算グラフ(computational graph) です。z = (x · y) + sin(x) のような式を、ノード=演算、エッジ=データの流れとして表現します。
x ──┬──► [×] ──► a ──┐
│ ▲ ▼
y ──┘ │ [+] ──► z
│ ▲
x ──────► [sin] ─► b ─┘
各ノードは自分の出力を入力で微分した局所的な値を知っています。[×] ノードなら a = x·y なので ∂a/∂x = y、∂a/∂y = x。[sin] ノードなら ∂b/∂x = cos(x)。全体の勾配 ∂z/∂x は、x から z へ至るすべての経路の局所微分を掛け、経路ごとに足し合わせることで得られます(多変数連鎖律)。上の例なら経路は2本あり、∂z/∂x = y · 1 + cos(x) · 1。
つまり一度グラフを組んでしまえば、勾配計算は「どの順序でこの掛け算と足し算をたどるか」という、純粋にグラフ走査の問題に帰着します。ここで走査の向きが2つあり、それが前進モードと逆伝播モードの分かれ目です。
前進モードと逆伝播モード:走査の向きが計算量を決める
| 観点 | 前進モード(forward) | 逆伝播モード(reverse) |
|---|---|---|
| 走る向き | 入力→出力(順方向と同時) | 出力→入力(順走査の後に逆走査) |
| 1回で得るもの | 1入力に対する全出力の微分 | 1出力に対する全入力の微分 |
| 得意な形 | 入力が少なく出力が多い | 入力が多く出力が1つ(=損失) |
| 中間値の保存 | ほぼ不要 | 順走査の値を全部保持する必要あり |
| 深層学習での使用 | ヤコビアン・ベクトル積などで補助的 | 主役(誤差逆伝播そのもの) |
前進モードは、各変数に「値」と「その変数を入力 x で微分した接ベクトル(微係数)」をペアで持たせ、計算を順方向に進めながら微分も同時に伝播させます。x を1つ動かしたときに全出力がどう動くか(ヤコビアンの1列)を1パスで得ます。入力次元 n、出力次元 m なら、全勾配を得るには入力ごとに n 回パスが必要です。
逆伝播モードは逆です。まず順方向に計算して全中間値を記録し、次に出力側から ∂出力/∂各ノード(随伴, adjoint)を逆向きに伝播させます。1回の逆走査で、1つの出力に対する**全入力の勾配(ヤコビアンの1行)**が得られます。
深層学習では n=パラメータ数が数億、m=損失はスカラー1つ。だから逆伝播モードが圧倒的に有利です。前進モードなら数億回パスが必要なところを、逆伝播モードは順1回+逆1回で全勾配を出します。これが「ニューラルネットの学習=逆伝播(backpropagation)」である根本理由で、誤差逆伝播は逆伝播モード自動微分の特殊例にすぎません(土台はニューラルネットワーク)。
逆走査では各演算の局所微分を計算するのに順走査時の中間値が要ります(例: y=exp(x) の微分 exp(x) は順走査の出力を再利用)。そのため中間値を全部保持せねばならず、層が深いほどメモリを食います。学習時にバッチサイズや層数でメモリが逼迫するのは、活性化値に加えてこの勾配計算用の中間値の保持が効いているためです。再計算でメモリを節約する勾配チェックポイント(gradient checkpointing)はこのトレードオフへの対処です。
テープ方式:PyTorch はグラフを実行時に動的に組む
PyTorch の autograd は テープベース(tape-based) の動的計算グラフです。「テープ」とは、順方向の演算を実行しながら、その記録を時系列に書き残していく仕組みを指します。
import torch
x = torch.tensor(2.0, requires_grad=True) # 勾配を追跡する印
y = x ** 3 + 2 * x # 演算のたびにグラフに記録
y.backward() # テープを逆向きに再生し勾配計算
print(x.grad) # dy/dx = 3x^2 + 2 = 14.0
ポイントは、Python のコードを普通に実行するだけでグラフが構築されることです。requires_grad=True のテンソルが関わる演算ごとに、PyTorch は結果テンソルへ grad_fn(その演算の逆向き微分を知る関数オブジェクト)を貼り付けます。こうして演算が連なると、出力から入力へさかのぼれる逆向きの関数の連鎖が自然にできあがります。これがテープの実体です。
backward() を呼ぶと、出力ノードの随伴を 1 で初期化し、grad_fn を逆順にたどって各テンソルの .grad に勾配を累積します。if 文やループでグラフの形が毎回変わってもよい(define-by-run)のが動的グラフの強みで、可変長系列や条件分岐の多いモデルを自然に書けます。
PyTorch は .grad を上書きせず足し込みます。これはミニバッチを分割して勾配を合算する用途に便利な仕様ですが、学習ループで毎ステップ optimizer.zero_grad()(または .grad = None)を呼ばないと、前ステップの勾配が残り続け、事実上の学習率が暴れて発散します。「学習が不安定」「lossがNaN」の典型原因の一つです。
トレース方式:JAX は関数を変換して微分する
JAX は思想が異なります。テープに「記録」するのではなく、grad という関数変換で、元の関数 f から「勾配を返す新しい関数 f'」を生成します。
import jax
import jax.numpy as jnp
def f(x):
return x ** 3 + 2 * x
df = jax.grad(f) # f を微分した“別の関数”を合成
print(df(2.0)) # 14.0
JAX は f を一度トレースして中間表現(jaxpr)という純粋な計算グラフに落とし、それに対し連鎖律を適用して微分版のグラフを構築します。関数が純粋(副作用なし・同じ入力なら同じ出力)である前提を置くことで、jit(コンパイル)・vmap(自動ベクトル化)・grad を自由に合成できます。jax.grad(jax.grad(f)) で2階微分、と変換を重ねられるのもこの設計ゆえです。
| 観点 | PyTorch autograd | JAX |
|---|---|---|
| グラフ構築 | 実行と同時に記録(define-by-run) | 関数をトレースして生成(define-then-run寄り) |
| API の発想 | テンソルが勾配を持ち back する | 関数を grad で変換し別関数を得る |
| 制御フロー | Python の if/for がそのまま使える | トレース可能な形(lax.cond等)が必要 |
| 高階微分 | 可能だがやや手数 | 変換の合成で自然に表現 |
| 最適化 | eager中心、コンパイルは別途 | jit と一体で関数全体を最適化しやすい |
どちらも裏でやっているのは逆伝播モード自動微分による連鎖律の自動適用で、本質は同じです。違いは「グラフをいつ・どう作るか」という戦略にあり、動的で書きやすい PyTorch と、純粋関数を変換して最適化しやすい JAX、という得手不得手の差として現れます。
まとめ:backward() の裏側で起きていること
| 論点 | 実態 | そこから言えること |
|---|---|---|
| 自動微分の正体 | 基本演算の局所微分を連鎖律で合成 | 厳密かつ元の計算と同オーダーで勾配が出る |
| なぜ逆伝播か | 損失1・パラメータ多数だと逆向きが効率的 | 順1回+逆1回で全勾配。深層学習の標準 |
| コストの所在 | 逆走査に順走査の中間値が必要 | 計算量より先にメモリが律速になりやすい |
| フレームワーク差 | PyTorch=動的テープ / JAX=関数変換 | 書きやすさと最適化のしやすさのトレードオフ |
loss.backward() や jax.grad(f) という一行の裏では、プログラムが基本演算の計算グラフに分解され、連鎖律が出力から入力へ機械的に積み上げられています。手で微分式を書く必要も、数値近似で誤差に悩む必要もありません。この仕組みを理解しておくと、メモリが足りない理由、勾配のゼロクリアが要る理由、高階微分が書ける理由が一本の線でつながり、ディープラーニングの学習がなぜ動くのかという問いに、原理から答えられるようになります。最適化アルゴリズムが受け取るのは、まさにこの自動微分が吐き出した勾配です。
AI/機械学習 Article
自動微分(Autograd)の仕組みを実務で読む
TL;DRは入口です。実際に選ぶ・使う段階では、何を解決するか、何と比較するか、導入後にどこで詰まるかまで見る必要があります。
解決すること
自動微分
比較で見る軸
難易度: advanced / カテゴリ: AI/機械学習 / タグ数: 5
導入後に効く点
逆伝播モードは出力1つに対し全入力の勾配を1回の逆走査で得るため、損失1つ・パラメータ多数の深層学習に最適。前進モードは逆。
先に潰すリスク
用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。
- 難易度
- advanced
- カテゴリ
- AI/機械学習
- タグ数
- 5
判断チェックリスト
- 自社の用途が「自動微分 / 誤差逆伝播」に近いか確認する。
- 強みである「自動微分は連鎖律を機械的に適用して厳密な勾配を計算する技術。数値微分の丸め誤差も、数式微分の式爆発も避けられます。」が本当に評価軸になるか確認する。
- 注意点の「用語だけ覚えても、設計・実装・運用でどこに効くかを確認しないと判断を誤る。」を運用で吸収できるか確認する。
- 公開値や仕様値は、対象プラン・対象機種・対象リージョンまで確認する。
- 既存システム、ID、ネットワーク、監視、バックアップとの接続方法を先に洗い出す。
- 小さく試してから、本番移行、権限設計、障害時手順、コスト監視を決める。