Transformer 模型 GPU 显存分析(三):反向传播需要保存哪些中间结果?

为了实现反向传播,前向传播时需要计算并保存一些必要的“中间值”[1]

本文将详细讨论在 Transformer 架构的前向传播过程中,具体需要保存哪些中间值。

基本原则

核心原则:反向传播时,某个梯度公式如果要用到前向里的某个“中间值”,这个“中间值”就要暂存。

以线性层举例。对于

y=xWy = xW

反向传播:

LW=xLy\frac{\partial L}{\partial W} = x^\top \frac{\partial L}{\partial y}

Lx=LyW\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} W^\top

因此为了算 L/W\partial L / \partial W ,需要保存输入 xx[2]

Transformer 架构中需保存的中间值汇总(有问题,暂时不要看)

这些值是用 GPT 5.5 总结的:

设:

符号 含义
B batch size
T sequence length
D hidden size
h attention heads 数
d 每个 head 的维度,d = D / h
M FFN 中间层维度,常见为 4D
Vocab 词表大小

假设使用普通 eager attention,而不是 FlashAttention,也没有 activation checkpointing。

模块 前向公式 反向需要保存的大中间值 常见维度
Embedding x=Embed(tokens)x = \mathrm{Embed}(\mathrm{tokens}) embedding backward 需要 tokens\mathrm{tokens}xx 作为后续 Transformer block 输入通常会被保存或 checkpoint 重算 tokens:B×T\mathrm{tokens}: B \times Tx:B×T×Dx: B \times T \times D
LayerNorm 1 u=LN1(x)u = \mathrm{LN}_1(x) LayerNorm backward 通常需要输入 xx 或归一化结果 x^\hat{x},以及 mean/rstd\mathrm{mean}/\mathrm{rstd} x:B×T×Dx: B \times T \times Dx^:B×T×D\hat{x}: B \times T \times Dmean:B×T\mathrm{mean}: B \times Trstd:B×T\mathrm{rstd}: B \times T
QKV Linear Q=uWQ, K=uWK, V=uWVQ = uW_Q,\ K = uW_K,\ V = uW_V Linear backward 需要输入 uu;attention backward 需要 Q,K,VQ,K,V 或通过 recompute 得到它们 u:B×T×Du: B \times T \times DQ:B×h×T×dQ: B \times h \times T \times dK:B×h×T×dK: B \times h \times T \times dV:B×h×T×dV: B \times h \times T \times d
Attention scores S=QKdS = \frac{QK^\top}{\sqrt d} 普通 softmax backward 通常不必须保存 SS;某些实现可能保存 mask 后 scores;FlashAttention 类实现通常不保存完整 SS S:B×h×T×TS: B \times h \times T \times T
Attention mask,可选 Smask=S+MattnS_{\mathrm{mask}} = S + M_{\mathrm{attn}} causal/padding mask 通常可由输入或规则重构,不一定作为大激活保存 Mattn:1×1×T×TM_{\mathrm{attn}}: 1 \times 1 \times T \times TB×1×1×TB \times 1 \times 1 \times TB×1×T×TB \times 1 \times T \times T
Softmax P=softmax(Smask)P = \mathrm{softmax}(S_{\mathrm{mask}}) 标准 eager attention 通常保存 softmax 概率 PP,用于 softmax backward;FlashAttention 不保存完整 PP,而保存较小统计量并在 backward 重算 P:B×h×T×TP: B \times h \times T \times T
Attention Dropout,可选 P~=Dropout(P)=mP1p\tilde P = \mathrm{Dropout}(P)=\frac{m\odot P}{1-p} 严谨地说,softmax backward 需要 PP,dropout backward 需要 mask mmP~\tilde P 可由 P,mP,m 重构,也可能被实现直接保存 P:B×h×T×TP: B \times h \times T \times Tm:B×h×T×Tm: B \times h \times T \times TP~:B×h×T×T\tilde P: B \times h \times T \times T
Attention Value 聚合 O=P~VO = \tilde P V matmul backward 需要 P~\tilde PVVP~\tilde P 可由 P,mP,m 重构;后续输出投影需要 OO reshape 后的 cc 作为输入 P~:B×h×T×T\tilde P: B \times h \times T \times TV:B×h×T×dV: B \times h \times T \times dO:B×h×T×dO: B \times h \times T \times dc:B×T×Dc: B \times T \times D
Attention 输出投影 a=cWOa = cW_O Linear backward 需要输入 cc;通常不因该 Linear 本身必须保存输出 aa c:B×T×Dc: B \times T \times Da:B×T×Da: B \times T \times D
Residual Add 1 x=x+ax' = x + a residual add 本身通常不需要保存大中间值;但 xx' 作为后续 LayerNorm/FFN 输入通常会被保存或重算 x:B×T×Dx': B \times T \times D
LayerNorm 2 v=LN2(x)v = \mathrm{LN}_2(x') LayerNorm backward 通常需要输入 xx' 或归一化结果 x^\hat{x}',以及 mean/rstd\mathrm{mean}/\mathrm{rstd} x:B×T×Dx': B \times T \times Dx^:B×T×D\hat{x}': B \times T \times Dmean:B×T\mathrm{mean}: B \times Trstd:B×T\mathrm{rstd}: B \times T
FFN Linear 1,GELU-FFN z=vW1+b1z = vW_1 + b_1 Linear backward 需要输入 vv;GELU backward 需要 zz 或等价中间值 v:B×T×Dv: B \times T \times Dz:B×T×Mz: B \times T \times M
GELU 激活 g=GELU(z)g = \mathrm{GELU}(z) GELU backward 需要 zz 或等价中间值;FFN Linear 2 backward 需要输入 gg z:B×T×Mz: B \times T \times Mg:B×T×Mg: B \times T \times M
FFN Linear 1,SwiGLU-FFN,可选 z1=vWgate, z2=vWupz_1 = vW_{\mathrm{gate}},\ z_2 = vW_{\mathrm{up}} 两个 Linear backward 需要输入 vv;SwiGLU backward 需要 z1,z2z_1,z_2 或等价中间值 v:B×T×Dv: B \times T \times Dz1:B×T×Mz_1: B \times T \times Mz2:B×T×Mz_2: B \times T \times M
SwiGLU 激活,可选 g=SiLU(z1)z2g = \mathrm{SiLU}(z_1)\odot z_2 SwiGLU backward 需要 z1,z2z_1,z_2 或等价中间值;后续 Linear backward 需要输入 gg z1:B×T×Mz_1: B \times T \times Mz2:B×T×Mz_2: B \times T \times Mg:B×T×Mg: B \times T \times M
FFN Linear 2 y=gW2+b2y = gW_2 + b_2 Linear backward 需要输入 gg;该 Linear 本身不必须保存输出 yy g:B×T×Mg: B \times T \times My:B×T×Dy: B \times T \times D
FFN Dropout,可选 y~=Dropout(y)=mffny1p\tilde y = \mathrm{Dropout}(y)=\frac{m_{\mathrm{ffn}}\odot y}{1-p} dropout backward 需要 mask mffnm_{\mathrm{ffn}};若后续需要,可保存或重算 y,y~y,\tilde y y:B×T×Dy: B \times T \times Dmffn:B×T×Dm_{\mathrm{ffn}}: B \times T \times Dy~:B×T×D\tilde y: B \times T \times D
Residual Add 2 out=x+y~\mathrm{out} = x' + \tilde y residual add 本身通常不需要额外保存大中间值;out\mathrm{out} 作为下一层输入通常会被保存或 checkpoint 重算 out:B×T×D\mathrm{out}: B \times T \times D
LM Head / Classifier =xfinalWvocab\ell = x_{\mathrm{final}}W_{\mathrm{vocab}} Linear backward 需要输入 xfinalx_{\mathrm{final}};普通 cross entropy 可能保存 logits;fused cross entropy 可能不完整保存 logits xfinal:B×T×Dx_{\mathrm{final}}: B \times T \times D:B×T×Vocab\ell: B \times T \times \mathrm{Vocab}

各层分析

线性层

对一个线性层:

Y=XW Y = XW

反向传播计算:

LW=XLY \frac{\partial L}{\partial W} = X^\top \frac{\partial L}{\partial Y}

LX=LYW \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W^\top

所以 backward 需要:

  • 输入 X X
  • 权重 W W ,但权重属于参数,不算 activation
  • 上游传回来的梯度 LY\frac{\partial L}{\partial Y}

所以,保存的中间值只有输入 XX

Dropout 层

Dropout(本质就是掩码) 的反向传播通常只需要保存掩码:

Y=XM Y = X \odot M

反向传播计算:

LX=LYM \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} \odot M

所以 backward 需要:

  • 上游传回来的梯度LY\frac{\partial L}{\partial Y}
  • 掩码 MM

所以保存的中间值只有掩码 MM

残差层

Residual add:

Z=X+Y Z = X + Y

反向传播计算:

LX=LZ \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Z}

LY=LZ \frac{\partial L}{\partial Y} = \frac{\partial L}{\partial Z}

所以 backward 需要:

  • 上游传回来的梯度LZ\frac{\partial L}{\partial Z}

残差层没有需要保存的中间值。

LayerNorm 层

LayerNorm 形式大致是:

y=γxμσ2+ϵ+β y = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

反向传播时,要计算:

Lx,Lγ,Lβ \frac{\partial L}{\partial x}, \quad \frac{\partial L}{\partial \gamma}, \quad \frac{\partial L}{\partial \beta}

因此 backward 至少需要知道 forward 的输入 x x ,以及通常还会用到均值、方差或 rstd 之类的统计量。

但是值得注意的是,里面内存占用占大头的是 xx,所以中间值一般保存 xx[3]

GeLU 激活

GeLU 是逐元素非线性,例如:

y=GeLU(x) y = \text{GeLU}(x)

它的 backward 是:

Lx=LyGeLU(x) \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \text{GeLU}'(x)

LZ\frac{\partial L}{\partial Z} 是上游传回的梯度。注意这里的导数依赖于 forward 的输入 x x

所以为了 backward,GeLU 需要保存的中间值是 xx

Softmax 激活

Softmax 是向量上的归一化操作,例如:

y=Softmax(z),yi=ezijezjy = \text{Softmax}(z),\qquad y_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}}

它的 backward 是:

Lz=y(LyLy,y)\frac{\partial L}{\partial z} = y \odot \left( \frac{\partial L}{\partial y} - \langle \frac{\partial L}{\partial y}, y \rangle \right)

[4]

其中 Ly\frac{\partial L}{\partial y} 是上游传回的梯度。注意这里的导数依赖于 forward 的输出 yy

所以为了 backward,Softmax 需要保存的中间值是 yy

评论