跳至主要內容

激活函数总结

Genhiy2024年4月17日...大约 3 分钟AI激活函数

Sigmoid

问:为什么大模型会有梯度消失问题?

答:sigmoid函数的导数取值范围是(0, 0.25],小于1的数乘在一起,必然是越乘越小的。这才仅仅是3层,如果10层的话, 根据0.25100.0000009540.25^{10}≈ 0.000000954,第10层的误差相对第一层卷积的参数的梯度将是一个非常小的值,这就是所谓的“梯度消失”。

ReLU函数的改进就是它使得梯度导数的取值范围为(0,1),这样的话只要一条路径上的导数都是1,无论神经网络是多少层,这一部分的乘积都始终为1,因此深层的梯度也可以传递到浅层中。

SwiGLU

大模型基础|激活函数|从ReLU 到SwiGLUopen in new window

Google的PaLM和Meta的LLaMA都使用了SwiGLU来增强Transformer架构中的FFN层(Feed Forward Network)的性能。Transformer模型通过多头注意力层和FFN层交替工作。FFN层存在于Transformer架构的编码器和解码器部分中。

FFN层包括两个线性变换,中间插入一个非线性激活函数。最初的Transformer架构采用了ReLU激活函数。

FFN(x,W1,W2,b1,b2)=ReLU(xW1+b1)W2+b2 \mathrm{FFN}(x,W_1,W_2,b_1,b_2)=\mathrm{ReLU}(xW_1+b_1)W_2+b_2

在他们的实验中使用了不包含bias项的FFN,T5其实也是这么搞的:

FFN(x,W1,W2)=ReLU(xW1)W2 \mathrm{FFN}(x,W_1,W_2)=\mathrm{ReLU}(xW_1)W_2

论文《Gaussian Error Linear Units(GELUs)》提出了GELU,这是ReLU的平滑版本。作者使用了标准正态分布的累积分布函数(cdf)的近似计算来提高计算速度。论文《Swish: a Self-Gated Activation Function》提出了Swish,这也是对带有非零负值梯度的ReLU平滑版本。

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

def gelu(x):
   return x * norm.cdf(x)

def relu(x):
   return np.maximum(0, x)

def swish(x, beta=1):
   return x * (1 / (1 + np.exp(-beta * x)))   

GLU及其变体

GLU(Gated Linear Units)其实不算是一种激活函数,而是一种神经网络层。它是一个线性变换后面接门控机制的结构。其中门控机制是一个sigmoid函数用来控制信息能够通过多少。

GLU(x,W,V,b,c)=σ(xW+b)(xV+c) \mathrm{GLU}(x,W,V,b,c)=\sigma(xW+b)\otimes(xV+c)

其中σ\sigma为sigmoid函数,\otimes为逐元素乘。通过使用其他的激活函数我们就能够得到GLU的各种变体了。

比如说现在LLM中常用的SwiGLU其实就是采用Swish作为激活函数的GLU变体:

SwiGLU(x,W,V,b,c)=Swish1(xW+b)(xV+c) \mathrm{SwiGLU}(x,W,V,b,c)=\mathrm{Swish}_1(xW+b)\otimes(xV+c)

由于引入了更多的权重矩阵,通常会对隐藏层的大小做一个缩放,从而保证整体的参数量不变。代码实现如下:

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        super().__init__()
        hidden_dim = multiple_of * ((2 * hidden_dim // 3 + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, dim)
        self.w3 = nn.Linear(dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

注:为什么这里使用的是silu激活函数? 因为SiLU其实就是beta为1时的Swish激活函数:f(x)=xσ(x)f(x)=x\cdot\sigma(x)