神经网络的梯度

考虑到如下有d层的神经网络: ht = ft(ht − 1)  and  y =  ∘ fd ∘ ... ∘ f1(x)

其中t代表神经网络层数。

计算损失关于参数Wt的梯度为: $\frac{\partial\ell}{\partial\mathbf{W}^t}=\frac{\partial\ell_0}{\partial\mathbf{h}^d}\frac{\partial\mathbf{h}^d}{\partial\mathbf{h}^{d-1}}...\frac{\partial\mathbf{h}^{t+1}}{\partial\mathbf{h}^t}\frac{\partial\mathbf{h}^t}{\partial\mathbf{W}^t}$

这里的h都是代表向量,而我们知道,向量关于向量的导数是一个矩阵,因此上述公式除去首尾可以看作是d-t次的矩阵乘法

数值稳定性的常见问题:梯度爆炸、梯度消失

梯度爆炸

这里的梯度爆炸很好理解,举个例子,加入我神经网络的每层的梯度都是1.5,那么在经历100层之后,这个梯度就会来到4 * 1027,这就有可能会导致python浮点数达到一个上限。 ### 梯度爆炸问题 * 值超出值域 + 对16位浮点数尤为严重(6 * 10−5 − 6 * 104) * 对学习率敏感 + 学习率太大->大参数值->更大的梯度 + 学习率太小->模型训练无法得到进展 + 我们可能需要在训练过程中不断调整学习率 (学习率相当于步长,梯度相当于收敛方向) ## 梯度消失 和梯度爆炸类似,假如我神经网络每层的梯度都是0.8,那么在经历100层之后,这个梯度就会来到2 * 10−10,同样会导致python浮点数上限 ### 梯度消失问题 * 梯度值变成0 + 对16位浮点数尤为严重 * 训练没有进展 + 不管如何训练 * 对底部层尤为严重 + 仅仅在顶部层训练好 + 无法让神经网络更深(和浅神经网络没有区别) # 如何让训练更加稳定 * 目标:让梯度值在合理范围内 * 常用方法:将乘法变加法 + ResNet, LSTM * 归一化 + 梯度归一化,梯度裁剪 * 合理的权重初始和激活函数 # 合理的权重初始和激活函数 1. 合理的权重初始 * 将每层的输出和梯度都看作是随机函数,让他们的均值和方差都保持一致。 + 权重初始:在合理值区间里随机初始函数。因为在训练开始的时候容易有数值不稳定,例如远离最优解的地方损失函数表面可能很复杂,最优解附近的表面可能比较平。(使用N(0, 0.01)对小网络可能没啥影响,但是对深度神经网络就行不通了) 这边具体的介绍可以去看李沐老师的动手学深度学习数值稳定性部分,以MLP为例(假设没有激活函数),从数学角度推导了初始方差如何选择即保证正向方差:

    $$n_{t-1}*Var[{W_{i,j}}^t]=1$$
    其中$n_{t-1}$为输入维度即上一层的神经元数量。反向部分这边就不过多介绍,与正向情况类似。
    
* Xavier初始化

    由于难以同时满足$n_t\gamma_t=1$和$n_{t-1}\gamma_t=1$,因此Xavier使得$\gamma_t(n_{t-1}+n_t)/2=1$即$\gamma_t=2/(n_{t-1}+n_t)$。
    例如对当前层进行权重初始时选择正态分布$\mathcal{N}\left(0,\sqrt[]{2/(n_{t-1}+n_t)}\right)$或者均匀分布$\mathscr{U}\left(-\sqrt{6/(n_{t-1}+n_t)},\sqrt{6/(n_{t-1}+n_t)}\right)$。
  1. 激活函数

    前面提到过让训练更稳定需要让每层的输出和梯度的均值和方差保持一致,激活函数所采用的思想也是一样。 检查常用的激活函数,本文以sigmoid,tanh和ReLu函数为例,对这三个激活函数使用泰勒展开: $$\mathrm{sigmoid}(x)=\frac{1}{2}+\frac{x}{4}-\frac{x^3}{48}+O(x^5) \\ \tanh(x)=0+x-\frac{x^3}{3}+O(x^5) \\ relu(x)=0+x\quad\mathrm{for~}x\geq0$$ 我们可以知道,通常在神经网络的训练过程中权重W的值通常在0附近徘徊,因此一个合适的激活函数应当是当x趋近于0时,f(x) = x。由泰勒展开可知,当x=0时,tanh函数和ReLu函数均为0,而sigmoid函数值为1/2。因此,我们需要调整sigmoid函数: 4 × sigmoid(x) − 2 # 总结

  • 当数值过大或过小都会导致数值问题
  • 常发生在深度模型中,因为其会对n个数累乘
  • 合理的权重初始值和激活函数的选取可以提高数值稳定性

声明

本文为笔者的学习分享,如有错误,欢迎指正修改。