为了实现反向传播,前向传播时需要计算并保存一些必要的“中间值”。
本文将详细讨论在 Transformer 架构的前向传播过程中,具体需要保存哪些中间值。
基本原则
核心原则:反向传播时,某个梯度公式如果要用到前向里的某个“中间值”,这个“中间值”就要暂存。
以线性层举例。对于
y=xW
反向传播:
∂W∂L=x⊤∂y∂L
∂x∂L=∂y∂LW⊤
因此为了算 ∂L/∂W,需要保存输入 x。
这些值是用 GPT 5.5 总结的:
设:
| 符号 |
含义 |
B |
batch size |
T |
sequence length |
D |
hidden size |
h |
attention heads 数 |
d |
每个 head 的维度,d = D / h |
M |
FFN 中间层维度,常见为 4D |
Vocab |
词表大小 |
假设使用普通 eager attention,而不是 FlashAttention,也没有 activation checkpointing。
| 模块 |
前向公式 |
反向需要保存的大中间值 |
常见维度 |
| Embedding |
x=Embed(tokens) |
embedding backward 需要 tokens;x 作为后续 Transformer block 输入通常会被保存或 checkpoint 重算 |
tokens:B×T,x:B×T×D |
| LayerNorm 1 |
u=LN1(x) |
LayerNorm backward 通常需要输入 x 或归一化结果 x^,以及 mean/rstd |
x:B×T×D,x^:B×T×D,mean:B×T,rstd:B×T |
| QKV Linear |
Q=uWQ, K=uWK, V=uWV |
Linear backward 需要输入 u;attention backward 需要 Q,K,V 或通过 recompute 得到它们 |
u:B×T×D,Q:B×h×T×d,K:B×h×T×d,V:B×h×T×d |
| Attention scores |
S=dQK⊤ |
普通 softmax backward 通常不必须保存 S;某些实现可能保存 mask 后 scores;FlashAttention 类实现通常不保存完整 S |
S:B×h×T×T |
| Attention mask,可选 |
Smask=S+Mattn |
causal/padding mask 通常可由输入或规则重构,不一定作为大激活保存 |
Mattn:1×1×T×T 或 B×1×1×T 或 B×1×T×T |
| Softmax |
P=softmax(Smask) |
标准 eager attention 通常保存 softmax 概率 P,用于 softmax backward;FlashAttention 不保存完整 P,而保存较小统计量并在 backward 重算 |
P:B×h×T×T |
| Attention Dropout,可选 |
P~=Dropout(P)=1−pm⊙P |
严谨地说,softmax backward 需要 P,dropout backward 需要 mask m;P~ 可由 P,m 重构,也可能被实现直接保存 |
P:B×h×T×T,m:B×h×T×T,P~:B×h×T×T |
| Attention Value 聚合 |
O=P~V |
matmul backward 需要 P~ 和 V;P~ 可由 P,m 重构;后续输出投影需要 O reshape 后的 c 作为输入 |
P~:B×h×T×T,V:B×h×T×d,O:B×h×T×d,c:B×T×D |
| Attention 输出投影 |
a=cWO |
Linear backward 需要输入 c;通常不因该 Linear 本身必须保存输出 a |
c:B×T×D,a:B×T×D |
| Residual Add 1 |
x′=x+a |
residual add 本身通常不需要保存大中间值;但 x′ 作为后续 LayerNorm/FFN 输入通常会被保存或重算 |
x′:B×T×D |
| LayerNorm 2 |
v=LN2(x′) |
LayerNorm backward 通常需要输入 x′ 或归一化结果 x^′,以及 mean/rstd |
x′:B×T×D,x^′:B×T×D,mean:B×T,rstd:B×T |
| FFN Linear 1,GELU-FFN |
z=vW1+b1 |
Linear backward 需要输入 v;GELU backward 需要 z 或等价中间值 |
v:B×T×D,z:B×T×M |
| GELU 激活 |
g=GELU(z) |
GELU backward 需要 z 或等价中间值;FFN Linear 2 backward 需要输入 g |
z:B×T×M,g:B×T×M |
| FFN Linear 1,SwiGLU-FFN,可选 |
z1=vWgate, z2=vWup |
两个 Linear backward 需要输入 v;SwiGLU backward 需要 z1,z2 或等价中间值 |
v:B×T×D,z1:B×T×M,z2:B×T×M |
| SwiGLU 激活,可选 |
g=SiLU(z1)⊙z2 |
SwiGLU backward 需要 z1,z2 或等价中间值;后续 Linear backward 需要输入 g |
z1:B×T×M,z2:B×T×M,g:B×T×M |
| FFN Linear 2 |
y=gW2+b2 |
Linear backward 需要输入 g;该 Linear 本身不必须保存输出 y |
g:B×T×M,y:B×T×D |
| FFN Dropout,可选 |
y~=Dropout(y)=1−pmffn⊙y |
dropout backward 需要 mask mffn;若后续需要,可保存或重算 y,y~ |
y:B×T×D,mffn:B×T×D,y~:B×T×D |
| Residual Add 2 |
out=x′+y~ |
residual add 本身通常不需要额外保存大中间值;out 作为下一层输入通常会被保存或 checkpoint 重算 |
out:B×T×D |
| LM Head / Classifier |
ℓ=xfinalWvocab |
Linear backward 需要输入 xfinal;普通 cross entropy 可能保存 logits;fused cross entropy 可能不完整保存 logits |
xfinal:B×T×D,ℓ:B×T×Vocab |
各层分析
线性层
对一个线性层:
Y=XW
反向传播计算:
∂W∂L=X⊤∂Y∂L
∂X∂L=∂Y∂LW⊤
所以 backward 需要:
- 输入 X
- 权重 W,但权重属于参数,不算 activation
- 上游传回来的梯度 ∂Y∂L
所以,保存的中间值只有输入 X。
Dropout 层
Dropout(本质就是掩码) 的反向传播通常只需要保存掩码:
Y=X⊙M
反向传播计算:
∂X∂L=∂Y∂L⊙M
所以 backward 需要:
- 上游传回来的梯度∂Y∂L
- 掩码 M
所以保存的中间值只有掩码 M。
残差层
Residual add:
Z=X+Y
反向传播计算:
∂X∂L=∂Z∂L
∂Y∂L=∂Z∂L
所以 backward 需要:
- 上游传回来的梯度∂Z∂L
残差层没有需要保存的中间值。
LayerNorm 层
LayerNorm 形式大致是:
y=γσ2+ϵx−μ+β
反向传播时,要计算:
∂x∂L,∂γ∂L,∂β∂L
因此 backward 至少需要知道 forward 的输入 x,以及通常还会用到均值、方差或 rstd 之类的统计量。
但是值得注意的是,里面内存占用占大头的是 x,所以中间值一般保存 x。
GeLU 激活
GeLU 是逐元素非线性,例如:
y=GeLU(x)
它的 backward 是:
∂x∂L=∂y∂L⋅GeLU′(x)
∂Z∂L 是上游传回的梯度。注意这里的导数依赖于 forward 的输入 x。
所以为了 backward,GeLU 需要保存的中间值是 x。
Softmax 激活
Softmax 是向量上的归一化操作,例如:
y=Softmax(z),yi=∑jezjezi
它的 backward 是:
∂z∂L=y⊙(∂y∂L−⟨∂y∂L,y⟩)
其中 ∂y∂L 是上游传回的梯度。注意这里的导数依赖于 forward 的输出 y。
所以为了 backward,Softmax 需要保存的中间值是 y。
评论