transformer模型的GPU显存使用分析(三):反向传播到底需要哪些中间结果

基本原则

核心原则:反向传播时,某个梯度公式如果要用到前向里的某个“中间值”,这个“中间值”就要暂存。

以线性层举例。对于

$$ y = xW $$

反向传播:

$$ \frac{\partial L}{\partial W} = x^\top \frac{\partial L}{\partial y} $$

$$ \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} W^\top $$

因此为了算 $\partial L / \partial W $,需要保存输入 $x$[1]

transformer 架构中常存“中间值”汇总

这些值是用 GPT 5.5 总结的。暂时感觉没问题:

设:

符号 含义
B batch size
T sequence length
D hidden size
h attention heads 数
d 每个 head 的维度,d = D / h
M FFN 中间层维度,常见为 4D
Vocab 词表大小

假设使用普通 eager attention,而不是 FlashAttention,也没有 activation checkpointing。

模块 前向公式 反向需要保存的大中间值 常见维度
Embedding $x = \mathrm{Embed}(\mathrm{tokens})$ embedding backward 需要 $\mathrm{tokens}$;$x$ 作为后续 Transformer block 输入通常会被保存或 checkpoint 重算 $\mathrm{tokens}: B \times T$,$x: B \times T \times D$
LayerNorm 1 $u = \mathrm{LN}_1(x)$ LayerNorm backward 通常需要输入 $x$ 或归一化结果 $\hat{x}$,以及 $\mathrm{mean}/\mathrm{rstd}$ $x: B \times T \times D$,$\hat{x}: B \times T \times D$,$\mathrm{mean}: B \times T$,$\mathrm{rstd}: B \times T$
QKV Linear $Q = uW_Q,\ K = uW_K,\ V = uW_V$ Linear backward 需要输入 $u$;attention backward 需要 $Q,K,V$ 或通过 recompute 得到它们 $u: B \times T \times D$,$Q: B \times h \times T \times d$,$K: B \times h \times T \times d$,$V: B \times h \times T \times d$
Attention scores $S = \frac{QK^\top}{\sqrt d}$ 普通 softmax backward 通常不必须保存 $S$;某些实现可能保存 mask 后 scores;FlashAttention 类实现通常不保存完整 $S$ $S: B \times h \times T \times T$
Attention mask,可选 $S_{\mathrm{mask}} = S + M_{\mathrm{attn}}$ causal/padding mask 通常可由输入或规则重构,不一定作为大激活保存 $M_{\mathrm{attn}}: 1 \times 1 \times T \times T$ 或 $B \times 1 \times 1 \times T$ 或 $B \times 1 \times T \times T$
Softmax $P = \mathrm{softmax}(S_{\mathrm{mask}})$ 标准 eager attention 通常保存 softmax 概率 $P$,用于 softmax backward;FlashAttention 不保存完整 $P$,而保存较小统计量并在 backward 重算 $P: B \times h \times T \times T$
Attention Dropout,可选 $\tilde P = \mathrm{Dropout}(P)=\frac{m\odot P}{1-p}$ 严谨地说,softmax backward 需要 $P$,dropout backward 需要 mask $m$;$\tilde P$ 可由 $P,m$ 重构,也可能被实现直接保存 $P: B \times h \times T \times T$,$m: B \times h \times T \times T$,$\tilde P: B \times h \times T \times T$
Attention Value 聚合 $O = \tilde P V$ matmul backward 需要 $\tilde P$ 和 $V$;$\tilde P$ 可由 $P,m$ 重构;后续输出投影需要 $O$ reshape 后的 $c$ 作为输入 $\tilde P: B \times h \times T \times T$,$V: B \times h \times T \times d$,$O: B \times h \times T \times d$,$c: B \times T \times D$
Attention 输出投影 $a = cW_O$ Linear backward 需要输入 $c$;通常不因该 Linear 本身必须保存输出 $a$ $c: B \times T \times D$,$a: B \times T \times D$
Residual Add 1 $x' = x + a$ residual add 本身通常不需要保存大中间值;但 $x'$ 作为后续 LayerNorm/FFN 输入通常会被保存或重算 $x': B \times T \times D$
LayerNorm 2 $v = \mathrm{LN}_2(x')$ LayerNorm backward 通常需要输入 $x'$ 或归一化结果 $\hat{x}'$,以及 $\mathrm{mean}/\mathrm{rstd}$ $x': B \times T \times D$,$\hat{x}': B \times T \times D$,$\mathrm{mean}: B \times T$,$\mathrm{rstd}: B \times T$
FFN Linear 1,GELU-FFN $z = vW_1 + b_1$ Linear backward 需要输入 $v$;GELU backward 需要 $z$ 或等价中间值 $v: B \times T \times D$,$z: B \times T \times M$
GELU 激活 $g = \mathrm{GELU}(z)$ GELU backward 需要 $z$ 或等价中间值;FFN Linear 2 backward 需要输入 $g$ $z: B \times T \times M$,$g: B \times T \times M$
FFN Linear 1,SwiGLU-FFN,可选 $z_1 = vW_{\mathrm{gate}},\ z_2 = vW_{\mathrm{up}}$ 两个 Linear backward 需要输入 $v$;SwiGLU backward 需要 $z_1,z_2$ 或等价中间值 $v: B \times T \times D$,$z_1: B \times T \times M$,$z_2: B \times T \times M$
SwiGLU 激活,可选 $g = \mathrm{SiLU}(z_1)\odot z_2$ SwiGLU backward 需要 $z_1,z_2$ 或等价中间值;后续 Linear backward 需要输入 $g$ $z_1: B \times T \times M$,$z_2: B \times T \times M$,$g: B \times T \times M$
FFN Linear 2 $y = gW_2 + b_2$ Linear backward 需要输入 $g$;该 Linear 本身不必须保存输出 $y$ $g: B \times T \times M$,$y: B \times T \times D$
FFN Dropout,可选 $\tilde y = \mathrm{Dropout}(y)=\frac{m_{\mathrm{ffn}}\odot y}{1-p}$ dropout backward 需要 mask $m_{\mathrm{ffn}}$;若后续需要,可保存或重算 $y,\tilde y$ $y: B \times T \times D$,$m_{\mathrm{ffn}}: B \times T \times D$,$\tilde y: B \times T \times D$
Residual Add 2 $\mathrm{out} = x' + \tilde y$ residual add 本身通常不需要额外保存大中间值;$\mathrm{out}$ 作为下一层输入通常会被保存或 checkpoint 重算 $\mathrm{out}: B \times T \times D$
LM Head / Classifier $\ell = x_{\mathrm{final}}W_{\mathrm{vocab}}$ Linear backward 需要输入 $x_{\mathrm{final}}$;普通 cross entropy 可能保存 logits;fused cross entropy 可能不完整保存 logits $x_{\mathrm{final}}: B \times T \times D$,$\ell: B \times T \times \mathrm{Vocab}$

评论