背景
反向传播是现在训练模型的重要方法,但是在部分场景下,会遇到不可微分的函数,从而导致梯度传播失败。比如量化里的取整函数。因此,需要对梯度进行估计然后反向传播。
STE(Straight-Through Estimator)是2013年Yoshua Bengio等人针对梯度估计进行了研究,那篇论文提出了几种梯度估计的方法,并推导出了一些理论性质,然后通过实验证明,STE是效果最好的方法。
由于那篇论文很多篇幅在介绍较为复杂的估计方法,且理论推导也极为复杂,效果没有简洁的STE好,因此不对其进行详细介绍。
应用
恒等函数
STE在反向传播理论推导的时候,把不可微的原子函数(比如量化函数里有放缩和取整两部分,其中取整是不可微的原子函数)替换为恒等函数。
这种应用可以大大减少理论推导的难度,但是在代码里应用反向传播的时候不太方便,以pytorch框架为例,可能需要自己手写函数类
torch.autograd.Function
,具体文档可以查看:PyTorch:
Defining New autograd Functions — PyTorch Tutorials 2.5.0+cu124
documentation 。
SG函数
苏神在他的博客(VQ-VAE的简明介绍:量子化自编码器
- 科学空间|Scientific Spaces)中提出了一个函数
sg(stop gradient)
,代表梯度反向传播终止的恒等函数。
所以原来的不可微的原子函数可以写为\(f(x)=x+sg(x'-x)\)
,在前向传播的时候和x'相同,在反向传播的时候梯度等于直接对
x
反向传播梯度。
这个式子也能用于理论推导,但是不如视为恒等函数麻烦,但是在代码方面容易完成,可以使用pytorch的默认反向传播函数,不需要自定义
torch.autograd.Function
(当然自定义也可以完成任务)。
BitNet/bitnet/bitlinear.py at main · kyegomez/BitNet 代码示例:
# w_quant 是 w 量化(不可微)后的值
# STE using detach
w_quant = w + (weight_quant(w) - w).detach()
这里也给出 torch.autograd.Function
的一个示例(OneBit/transformers/src/transformers/models/bitnet.py
at main · xuyuzhuang11/OneBit):
class SignSTEFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return torch.sign(input)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output * (1.001 - torch.tanh(input) ** 2)
# return grad_output * (1.01 - torch.tanh(input) ** 2)
参考资料
- [1308.3432] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
- PyTorch: Defining New autograd Functions — PyTorch Tutorials 2.5.0+cu124 documentation
- VQ-VAE的简明介绍:量子化自编码器 - 科学空间|Scientific Spaces
- kyegomez/BitNet: Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in pytorch
- OneBit/transformers/src/transformers/models/bitnet.py at main · xuyuzhuang11/OneBit