梯度估计STE


背景

反向传播是现在训练模型的重要方法,但是在部分场景下,会遇到不可微分的函数,从而导致梯度传播失败。比如量化里的取整函数。因此,需要对梯度进行估计然后反向传播。

STE(Straight-Through Estimator)是2013年Yoshua Bengio等人针对梯度估计进行了研究,那篇论文提出了几种梯度估计的方法,并推导出了一些理论性质,然后通过实验证明,STE是效果最好的方法。

由于那篇论文很多篇幅在介绍较为复杂的估计方法,且理论推导也极为复杂,效果没有简洁的STE好,因此不对其进行详细介绍。

应用

恒等函数

STE在反向传播理论推导的时候,把不可微的原子函数(比如量化函数里有放缩和取整两部分,其中取整是不可微的原子函数)替换为恒等函数。

这种应用可以大大减少理论推导的难度,但是在代码里应用反向传播的时候不太方便,以pytorch框架为例,可能需要自己手写函数类 torch.autograd.Function,具体文档可以查看:PyTorch: Defining New autograd Functions — PyTorch Tutorials 2.5.0+cu124 documentation

SG函数

苏神在他的博客(VQ-VAE的简明介绍:量子化自编码器 - 科学空间|Scientific Spaces)中提出了一个函数 sg(stop gradient),代表梯度反向传播终止的恒等函数。

所以原来的不可微的原子函数可以写为\(f(x)=x+sg(x'-x)\) ,在前向传播的时候和x'相同,在反向传播的时候梯度等于直接对 x反向传播梯度。

这个式子也能用于理论推导,但是不如视为恒等函数麻烦,但是在代码方面容易完成,可以使用pytorch的默认反向传播函数,不需要自定义 torch.autograd.Function(当然自定义也可以完成任务)。

BitNet/bitnet/bitlinear.py at main · kyegomez/BitNet 代码示例:

# w_quant 是 w 量化(不可微)后的值
# STE using detach
w_quant = w + (weight_quant(w) - w).detach()

这里也给出 torch.autograd.Function的一个示例(OneBit/transformers/src/transformers/models/bitnet.py at main · xuyuzhuang11/OneBit):

class SignSTEFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output * (1.001 - torch.tanh(input) ** 2)
        # return grad_output * (1.01 - torch.tanh(input) ** 2)

参考资料


文章作者: bg51717
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 bg51717 !
由于评论系统依托于Github的Discuss存在,因此默认评论者会收到所有通知。可以在邮件里点击"unsubscribe"停止接受,后续也可以点击下列仓库进行通知管理: bg51717/Hexo-Blogs-comments
Since the comment system relies on GitHub's Discussions feature, by default, commentators will receive all notifications. You can click "unsubscribe" in the email to stop receiving them, and you can also manage your notifications by clicking on the following repositories: bg51717/Hexo-Blogs-comments
  目录