背景
反向传播是现在训练模型的重要方法,但是在部分场景下,会遇到不可微分的函数,从而导致梯度传播失败。比如量化里的取整函数。因此,需要对梯度进行估计然后反向传播。
STE(Straight-Through Estimator)是2013年Yoshua Bengio等人针对梯度估计进行了研究,那篇论文提出了几种梯度估计的方法,并推导出了一些理论性质,然后通过实验证明,STE是效果最好的方法。
由于那篇论文很多篇幅在介绍较为复杂的估计方法,且理论推导也极为复杂,效果没有简洁的STE好,因此不对其进行详细介绍。
应用
恒等函数
STE在反向传播理论推导的时候,把不可微的原子函数(比如量化函数里有放缩和取整两部分,其中取整是不可微的原子函数)替换为恒等函数。
这种应用可以大大减少理论推导的难度,但是在代码里应用反向传播的时候不太方便,以pytorch框架为例,可能需要自己手写函数类
torch.autograd.Function
,具体文档可以查看:PyTorch:
Defining New autograd Functions — PyTorch Tutorials 2.5.0+cu124
documentation 。
SG函数
苏神在他的博客(VQ-VAE的简明介绍:量子化自编码器
- 科学空间|Scientific Spaces)中提出了一个函数
sg(stop gradient)
,代表梯度反向传播终止的恒等函数。
所以原来的不可微的原子函数可以写为\(f(x)=x+sg(x'-x)\)
,在前向传播的时候和x'相同,在反向传播的时候梯度等于直接对
x
反向传播梯度。
这个式子也能用于理论推导,但是不如视为恒等函数麻烦,但是在代码方面容易完成,可以使用pytorch的默认反向传播函数,不需要自定义
torch.autograd.Function
(当然自定义也可以完成任务)。
BitNet/bitnet/bitlinear.py at main · kyegomez/BitNet 代码示例:
1 | # w_quant 是 w 量化(不可微)后的值 |
这里也给出 torch.autograd.Function
的一个示例(OneBit/transformers/src/transformers/models/bitnet.py
at main · xuyuzhuang11/OneBit):
1 | class SignSTEFunc(torch.autograd.Function): |
参考资料
- [1308.3432] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
- PyTorch: Defining New autograd Functions — PyTorch Tutorials 2.5.0+cu124 documentation
- VQ-VAE的简明介绍:量子化自编码器 - 科学空间|Scientific Spaces
- kyegomez/BitNet: Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in pytorch
- OneBit/transformers/src/transformers/models/bitnet.py at main · xuyuzhuang11/OneBit
Since the comment system relies on GitHub's Discussions feature, by default, commentators will receive all notifications. You can click "unsubscribe" in the email to stop receiving them, and you can also manage your notifications by clicking on the following repositories: bg51717/Hexo-Blogs-comments