LLM 结构和创新点
优化技术
大模型领域针对 Transformer 计算过程的一些优化
滑动窗口注意力
sliding window attention: Longformer
稀疏注意力
Generating Long Sequences with Sparse Transformers,核心是只让设置好的像素点参与自注意力的计算(注意这里不是只选取设置好位置上的像素点,其他mask掉,因为这样并不能降低模型的复杂度),引入一个名为连接模式(Connectivity Pattern)的变量,连接模式只作用在$K$和$V$的计算上,(bs, head_num, seq_len, head_dim) 中的 seq_len 只选取连接模式中选择的。
连接模式(注意力核)的选择:局部稀疏模式、分块稀疏模式、对角块稀疏模式
Flash Attention
其核心思想是将原始的注意力矩阵分解成更小的子矩阵,然后分别对这些子矩阵进行计算,只要这个子矩阵的大小可以在 SRAM 内存放,那么不就可以在计算过程中只访问 SRAM 了。
采用了 Recomputation (重算)方法,这算是在深度学习优化中的老概念了,它是一种算力换内存的把戏。
Tiling 方法将 NxN 的 softmax 分数矩阵划分为块,核心是 分块SoftMax算法,详细推导过程见技术博客。
Qwen2 模块
modeling_qwen2.py
架构
-
Qwen2RMSNorm: RMS归一化层
-
Qwen2RotaryEmbedding: 旋转位置编码
-
Attention
-
Qwen2Attention: 注意力层
-
Qwen2FlashAttention2: 使用Flash Attention 2.0版本加速的注意力层
-
Qwen2SdpaAttention: 使用Sdpa(pytorch自带的加速, Scaled Dot-Product Attention)加速的注意力层
-
-
Qwen2DecoderLayer: 编码层,核心结构,之后就是堆叠
-
Qwen2PreTrainedModel: 预训练类
-
Qwen2Model: 不带head的Qwen2模型
-
Qwen2ForCausalLM: 带Causal LM head的Qwen2模型
-
Qwen2ForSequenceClassification: 带序列分类头的Qwen2模型
功能函数
-
_get_unpad_data: 在flash attention的数据预处理中会用到。主要是对 attention mask 进行一些操作;
-
rotate_half: 在旋转位置编码中用到;
-
apply_rotary_pos_emb: 对数据主要是注意力运算中的 q,k 做旋转位置编码;
-
repeat_kv: 主要是在 MQA(Multi-Query Attention)和 GQA(Group-Query Attention)中用到,因为 q head 数量是 k,v head 的数量的整数倍
优化注意力机制,降低处理长序列时的计算复杂度。
分组查询注意力 GQA
MQA 将所有注意力头的键和值共享
GQA 的方式是将多个注意力头分成组,每组头共享同一组的键和值。
实现方式:transformers.models.llama.modeling_llama.repeat_kv
复制多次 KV,使用 expand
而不是 repeat
是因为 KV 的参数是组内共享的
旋转位置编码 PoRE
通过旋转编码,使得每个 token 既有相对位置信息,又有绝对位置信息,Qwen2 的位置信息编码是在 attention 中计算到 KV 中的
例如,对于一个维度为 512 的向量$v=(v_0,v_1,v_2,v_3,…,v_{511})$,RoPE 可能将其看作:
- 第 1 对:$(v_0,v_1)$
- 第 2 对:$(v_2,v_3)$
- …
- 第 256 对:$(v_{510},v_{511})$
然后,通过旋转矩阵分别对每一对进行旋转,使得每一对的旋转角度与相对位置相对应。
代码层面,Qwen2 构造 Qwen2RotaryEmbedding
类和 apply_rotary_pos_emb
方法, Qwen2RotaryEmbedding
类基于 seq_len 返回缓存的 cos 和 sin 数据, apply_rotary_pos_emb
方法将位置信息作用到$K$和$V$上。
均方根归一化 RMSNorm
矩阵计算要进行归一化防止不同特征的取值过大,常用的是 layer norm,也就是每一项减去样本的均值,再除以样本的方差;而 RMS 则是去除了减去均值的操作,以便提升效率。
- LayerNorm 通过均值和标准差对输入进行标准化,确保输入具有零均值和单位方差,这对稳定训练有帮助。
- RMSNorm 则通过均方根进行归一化,不关注均值,仅对幅值进行规范化,避免了对数据的中心化处理,减少了计算复杂度,并且在某些模型中表现优越。
$\text{LN}(x) = \frac{x - \mu}{\sigma} \cdot \gamma + \beta$
- x 是输入向量(通常是模型中每一层的输出)。
- $\mu$ 是输入向量的均值。
- $\sigma$ 是输入向量的标准差。
- $\gamma$ 和 $\beta$ 是可学习的缩放和偏移参数。
LN 会计算整个输入的均值 $\mu$ 和标准差 $ \sigma $,并用它们对输入进行标准化。
$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma$
- $\text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2}$是输入向量的均方根;
- $\gamma $ 是可学习的缩放参数;
- 不使用均值和偏移参数 $\beta$ ;
RMSNorm 只关注输入的 范数(幅值),而不考虑均值,避免了对数据的中心化处理。
它是一种更轻量级的标准化方法,在某些情况下可以提高训练效率和收敛速度。
激活函数 SwiGLU
SwiGLU 将 GLU 的门控改为 Swish
,后验证实其在 GLU 的众多变体里效果最好,优于 Transformer 一开始使用的 ReLU
。
$F.silu(self.w1(x)) * self.w2(x)$,其中 silu 公式为 $\mathrm{Swish}_\beta(x)=x\otimes\sigma(\beta x)$