【大模型】从0到1训练一个大模型(六):MOE
混合专家模型(Mixture of Experts,MOE)是一种集成学习模型,它由多个专家网络(Expert Networks)和一个门控网络(Gating Network)组成。每个专家网络是一个独立的神经网络,专门用于处理特定类型的数据或任务。门控网络的作用是根据输入数据,动态地决定每个专家网络的权重,即哪些专家网络应该参与当前输入的处理,以及每个专家网络的贡献程度。
前言
在大模型的发展历程中,我们不断追求更强的性能、更高的效率以及更好的泛化能力。此前,我们已经完成了预训练、监督微调(SFT)以及直接偏好优化(DPO)等关键步骤,逐步提升模型在各种任务中的表现。然而,随着模型规模的不断增大,计算资源的需求也呈指数级增长,这给模型的训练和部署带来了巨大挑战。为了应对这些挑战,混合专家模型(Mixture of Experts,MOE)应运而生。MOE 模型通过引入多个专家网络和门控机制,在提升模型表达能力的同时,能够有效降低计算成本,为大模型的发展开辟了新的道路。本文将深入介绍 MOE 模型的原理、优势,并结合代码详细讲解如何从头训练一个 MOE 模型。
一、MOE模型简介
1.1 什么是 MOE 模型
混合专家模型(Mixture of Experts,MOE)是一种集成学习模型,它由多个专家网络(Expert Networks)和一个门控网络(Gating Network)组成。每个专家网络是一个独立的神经网络,专门用于处理特定类型的数据或任务。门控网络的作用是根据输入数据,动态地决定每个专家网络的权重,即哪些专家网络应该参与当前输入的处理,以及每个专家网络的贡献程度。具体来说,当输入数据进入 MOE 模型时,门控网络会计算每个专家网络的权重,然后将输入数据分别传递给这些专家网络进行处理,最后将各个专家网络的输出按照权重进行加权求和,得到最终的输出。
1.2 为什么需要 MOE 模型
1.提升模型表达能力
不同的专家网络可以学习到数据的不同特征和模式,通过将多个专家网络的输出进行组合,MOE 模型能够捕捉到更丰富、更复杂的信息,从而提升模型的表达能力。例如,在自然语言处理任务中,一个专家网络可能擅长处理语法结构,另一个专家网络可能擅长处理语义信息,通过 MOE 模型可以将这两种能力结合起来,提高模型在语言理解和生成任务中的表现。
2.降低计算成本
在传统的大型神经网络中,所有的神经元都会参与每一次的计算,这会导致计算量非常大。而在 MOE 模型中,门控网络会根据输入数据动态地选择部分专家网络进行计算,只有被选中的专家网络才会参与处理,从而减少了不必要的计算,降低了计算成本。特别是在处理大规模数据时,这种计算效率的提升尤为明显。
3.提高模型的可扩展性
MOE 模型可以很容易地扩展专家网络的数量,通过增加专家网络,可以进一步提升模型的表达能力和性能。而且,由于每个专家网络是独立的,可以并行计算,这使得 MOE 模型在分布式计算环境中具有很好的可扩展性。
二、代码及介绍
2.1 基础组件
2.1.1 归一化层(RMSNorm)
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
hidden_states = hidden_states.float()
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.float()
RMSNorm 是一种归一化方法,用于对输入的隐藏状态进行归一化处理。它通过计算输入的方差,并将输入除以方差的平方根,使得输入的分布更加稳定。这种归一化方法有助于加速模型的训练过程,提高模型的稳定性。
2.1.2 旋转位置嵌入(RotaryEmbedding)
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q*cos) + (rotate_half(q)*sin)
k_embed = (k*cos) + (rotate_half(k)*sin)
return q_embed, k_embed
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_seq_len=1024):
super(RotaryEmbedding, self).__init__()
self.dim = dim
self.max_seq_len = max_seq_len
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).float().unsqueeze(1)
freqs = t @ inv_freq.unsqueeze(0)
freqs = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", freqs.cos())
self.register_buffer("sin_cached", freqs.sin())
def forward(self, q, k):
cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)
sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)
return apply_rotate_pos_emb(q, k, cos, sin)
旋转位置嵌入用于为注意力机制中的查询(q)和键(k)向量添加位置信息。通过旋转操作,使得模型能够感知输入序列中不同位置的信息,从而更好地捕捉序列中的顺序和上下文关系。
2.1.3 注意力机制(Attention
class Attention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.k_cache, self.v_cache = None, None
self.is_causal = True
self.flash_attn = self.config.flash_attn
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self.residual_dropout = nn.Dropout(self.dropout)
self.attention_dropout = nn.Dropout(self.dropout)
self.rotary_emb = RotaryEmbedding(self.head_dim)
def forward(self, hidden_states, use_kv_cache=False):
b, s = hidden_states.shape[:2]
if use_kv_cache and self.eval():
if self.k_cache is None or self.k_cache.shape[1] != s - 1:
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
else:
token = hidden_states[:, -1:, :]
q = torch.cat((torch.zeros_like(hidden_states[:, :-1, :]), self.q_proj(token)), dim=1)
k = torch.cat((self.k_cache, self.k_proj(token)), dim=1)
v = torch.cat((self.v_cache, self.v_proj(token)), dim=1)
self.k_cache, self.v_cache = k, v
else:
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q = q.view(b, s, self.num_heads, self.head_dim)
k = k.view(b, s, self.num_key_value_heads, self.head_dim)
v = v.view(b, s, self.num_key_value_heads, self.head_dim)
q, k = self.rotary_emb(q, k)
k = repeat_kv(k, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if self.flash_attn:
output = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
dropout_p=self.dropout if self.training else 0.0,
is_causal=self.is_causal)
else:
mask = torch.full((1, 1, self.config.max_seq_len, self.config.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
scores = scores + self.mask[:, :, :s, :s]
scores = F.softmax(scores.float(), dim=-1).type_as(q)
scores = self.attention_dropout(scores)
output = torch.matmul(scores, v)
output = output.transpose(1, 2).contiguous().view(b, s, -1)
output = self.o_proj(output)
output = self.residual_dropout(output)
return output
注意力机制是 Transformer 架构中的核心组件,用于计算输入序列中不同位置之间的关联程度。在 MOE 模型中,注意力机制帮助模型聚焦于输入序列中的重要信息,从而更好地进行特征提取和表示学习。
2.1.4 多层感知机(MLP)
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
def forward(self, x):
down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
多层感知机用于对输入进行非线性变换,增加模型的表达能力。在 MOE 模型中,MLP 可以作为专家网络的一部分,用于处理特定类型的数据。
2.2 MOE 核心组件
2.2.1 负载均衡损失函数(load_balancing_loss_func)
def load_balancing_loss_func(
gate_logits,
num_experts,
top_k):
concatenated_gate_logits = torch.cat([layer_gate for layer_gate in gate_logits], dim=0)
routing_weights = F.softmax(concatenated_gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
router_prob_per_expert = torch.mean(routing_weights, dim=0)
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts
作用:该函数的主要目的是确保在训练过程中各个专家网络能够均衡地处理输入数据,避免出现某些专家网络过度使用,而另一些专家网络闲置的情况。这有助于提高模型整体的稳定性和效率,充分发挥每个专家网络的作用。
计算步骤:
(1)合并门控逻辑值:concatenated_gate_logits = torch.cat([layer_gate for layer_gate in gate_logits], dim=0) 将来自不同层的门控逻辑值(gate_logits)在维度 0 上进行拼接。这样做是因为我们要综合考虑所有层的门控信息,以全面评估专家网络的使用情况。
(2)计算路由权重:routing_weights = F.softmax(concatenated_gate_logits, dim=-1) 使用 softmax 函数将门控逻辑值转换为概率分布,得到每个输入对应各个专家网络的路由权重。这些权重表示每个专家网络对于处理当前输入的相对重要性。
(3)选择顶级专家:_, selected_experts = torch.topk(routing_weights, top_k, dim=-1) 通过 topk 操作,选择每个输入对应的概率最高的 top_k 个专家网络。这里的 selected_experts 记录了每个输入选择的专家网络索引。
(4)生成专家掩码:expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)使用 one_hot 编码将选择的专家网络索引转换为掩码矩阵。这个掩码矩阵可以直观地表示每个输入选择了哪些专家网络,其中 num_experts 是专家网络的总数。
(5)计算专家使用频率:tokens_per_expert = torch.mean(expert_mask.float(), dim=0) 计算每个专家网络被选择的平均频率。通过对掩码矩阵在维度 0 上求平均,得到每个专家网络被选中的概率,反映了每个专家网络处理的输入数量占总输入数量的比例。
(6)计算平均路由概率:router_prob_per_expert = torch.mean(routing_weights, dim=0) 计算每个专家网络的平均路由概率。通过对路由权重在维度 0 上求平均,得到每个专家网络在所有输入上的平均权重,反映了每个专家网络在处理输入时的平均重要性。
(7)计算整体损失:overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) 将专家使用频率与平均路由概率相乘并求和,得到整体的负载均衡损失。这一步的意义在于,综合考虑了专家网络的使用频率和其在处理输入时的重要性,以衡量专家网络负载的不均衡程度。最后乘以 num_experts 是为了适当缩放损失值,使其在训练过程中能够有效地影响模型的优化。
2.2.2 门控网络(Gating)
class Gating(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.topk = config.topk
self.expert_num = config.expert_num
self.gate = nn.Linear(self.hidden_size, self.expert_num)
def forward(self, x):
logits = self.gate(x)
logits_topk, indices = logits.topk(self.topk, dim=-1)
zeros = torch.full_like(logits, float("-inf"))
sparse_logits = zeros.scatter(dim=-1, index=indices, src=logits_topk)
sparse_logits = F.softmax(sparse_logits, dim=-1)
gate_logit = logits.view(-1, self.expert_num)
return sparse_logits, indices, gate_logit
作用:门控网络的主要功能是根据输入数据动态地决定哪些专家网络应该参与当前输入的处理,以及每个专家网络的贡献程度。它为 MOE 模型提供了一种自适应的路由机制,使得模型能够根据输入的特点灵活地分配计算资源。
计算步骤:
(1)线性变换:logits = self.gate(x) 通过一个线性层 self.gate 将输入 x 从 hidden_size 维度映射到 expert_num 维度,得到每个专家网络对应的原始分数(逻辑值)。这些分数表示输入与每个专家网络的匹配程度。
(2)选择顶级专家:logits_topk, indices = logits.topk(self.topk, dim=-1) 从原始分数中选择概率最大的 top_k 个专家网络。logits_topk 是这 top_k 个专家网络对应的分数,indices 则记录了这些专家网络的索引。
(3)生成稀疏矩阵:
zeros = torch.full_like(logits, float("-inf")) 创建一个与 logits 形状相同的全负无穷矩阵。
sparse_logits = zeros.scatter(dim=-1, index=indices, src=logits_topk) 使用 scatter函数将 logits_topk 中的值按 indices 所指定的位置填充到全负无穷矩阵中,从而得到一个稀疏矩阵。这个稀疏矩阵只保留了概率最大的 top_k 个专家网络的分数,其他位置为负无穷。
sparse_logits = F.softmax(sparse_logits, dim=-1) 对稀疏矩阵进行 softmax 操作,将其转换为概率分布。这样得到的概率分布表示每个输入对应的 top_k 个专家网络的相对权重,即每个专家网络在处理当前输入时的贡献比例。
(4)调整输出形状:gate_logit = logits.view(-1, self.expert_num) 将原始的门控逻辑值 logits 调整形状,以便后续计算负载均衡损失等操作。最后返回稀疏概率矩阵 sparse_logits、选择的专家网络索引 indices 和调整形状后的门控逻辑值 gate_logit。
2.2.3 专家网络(Expert)
class Expert(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
def forward(self, x):
down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
作用:专家网络是 MOE 模型中负责处理特定类型数据或任务的独立神经网络。每个专家网络都有自己独立的参数,能够学习到不同的数据特征和模式,从而为模型提供多样化的处理能力。通过门控网络的选择,不同的专家网络可以针对不同的输入发挥作用,共同提升模型的整体性能。
计算步骤:
(1)中间层变换
self.gate_proj(x) 和 self.up_proj(x) 分别通过线性层将输入 x 从 hidden_size 维度映射到 intermediate_size 维度。这里使用两个不同的线性层进行映射,是为了在中间层引入更多的参数和非线性变换能力,增强专家网络的表达能力。
F.silu(self.gate_proj(x)) 使用 silu 激活函数对 self.gate_proj(x) 的输出进行非线性变换。silu 函数(Sigmoid Linear Unit)可以增加模型的非线性表达能力,使专家网络能够学习到更复杂的特征。
F.silu(self.gate_proj(x)) * self.up_proj(x) 将经过 silu 激活后的结果与 self.up_proj(x) 的输出相乘。这种逐元素相乘的操作可以对特征进行进一步的筛选和组合,挖掘输入数据中不同特征之间的相互关系。
(2)输出层变换:self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) 使用另一个线性层 self.down_proj 将中间层处理后的结果从 intermediate_size 维度映射回 hidden_size 维度,得到专家网络的最终输出。这个输出包含了专家网络对输入数据的特征提取和处理结果,将作为 MOE 模型最终输出的一部分。
2.2.4 混合专家模型(MoE)
class MoE(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.experts = nn.ModuleList([Expert(config) for _ in range(config.expert_num)])
self.gating = Gating(config)
def forward(self, x):
sparse_logits, indices, gate_logit = self.gating(x)
final_outputs = torch.zeros_like(x)
x_flat = x.view(-1, x.shape[-1])
sparse_logits_flat = sparse_logits.view(-1, sparse_logits.shape[-1])
for i, expert in enumerate(self.experts):
expert_mask = (indices == i).any(-1)
expert_mask_flat = expert_mask.view(-1)
if expert_mask_flat.any():
expert_input = x_flat[expert_mask_flat]
export_output = expert(expert_input)
gate_scores = sparse_logits_flat[expert_mask_flat, i].unsqueeze(1)
weighted_output = export_output * gate_scores
final_outputs[expert_mask] += weighted_output
return final_outputs, gate_logit
作用:MoE 模型将多个专家网络和门控网络组合在一起,实现了根据输入数据动态选择和组合专家网络的功能。它通过门控网络为每个输入分配最合适的专家网络,并将专家网络的输出进行加权求和,从而得到最终的输出。这种结构使得模型能够充分利用各个专家网络的优势,提高模型的表达能力和适应性。
计算步骤:
(1)门控计算:sparse_logits, indices, gate_logit = self.gating(x) 首先将输入 x 传入门控网络 self.gating,得到每个输入对应的稀疏概率矩阵 sparse_logits(表示选择的专家网络的权重)、选择的专家网络索引 indices 和门控逻辑值 gate_logit。
(2)初始化输出:final_outputs = torch.zeros_like(x) 创建一个与输入 x 形状相同的全零张量 final_outputs,用于存储最终的输出结果。
(3)调整形状:
x_flat = x.view(-1, x.shape[-1]) 将输入 x 展平为二维张量,方便后续根据专家选择进行切片操作。
sparse_logits_flat = sparse_logits.view(-1, sparse_logits.shape[-1]) 将稀疏概率矩阵 sparse_logits 也展平为二维张量,以便与展平后的输入相对应。
(4)专家网络处理与加权求和:
遍历所有专家网络 for i, expert in enumerate(self.experts):。
expert_mask = (indices == i).any(-1) 创建一个掩码,用于判断每个输入是否选择了当前专家网络 i。(indices == i) 比较每个输入选择的专家网络索引是否等于当前专家网络的索引 i,any(-1) 则在最后一个维度上检查是否有任何元素为真。
expert_mask_flat = expert_mask.view(-1) 将掩码展平为一维张量,以便与展平后的输入和稀疏概率矩阵进行对应。
如果当前专家网络被至少一个输入选择 if expert_mask_flat.any()::
1.expert_input = x_flat[expert_mask_flat] 从展平后的输入中提取选择当前专家网络的输入部分。
2.export_output = expert(expert_input) 将这些输入传递给当前专家网络 expert 进行处理,得到专家网络的输出。
3.gate_scores = sparse_logits_flat[expert_mask_flat, i].unsqueeze(1) 从展平后的稀疏概率矩阵中提取当前专家网络对应的权重,并增加一个维度,使其与专家网络的输出维度匹配。
4.weighted_output = export_output * gate_scores 将专家网络的输出乘以对应的权重,得到加权后的输出。
5.final_outputs[expert_mask] += weighted_output 将加权后的输出累加到最终输出张量 final_outputs 中对应的位置。
(5)返回结果
最后返回最终的输出 final_outputs 和门控逻辑值 gate_logit,其中门控逻辑值可用于后续计算负载均衡损失等操作。
2.3 模型其他组件
2.3.1 解码器层(DecoderLayer)
class DecoderLayer(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Attention(config)
self.moe = MoE(config)
self.mlp = MLP(config)
self.input_layernorm = RMSNorm(config.hidden_size)
self.post_attention_layernorm = RMSNorm(config.hidden_size)
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
use_kv_cache
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
use_kv_cache=use_kv_cache
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
if self.layer_idx % 2 == 0:
hidden_states = self.mlp(hidden_states)
gate_logit = None
else:
hidden_states, gate_logit = self.moe(hidden_states)
outputs = residual + hidden_states
return outputs, gate_logit
解码器层是 MOE 模型中的一个基本单元,它包含了注意力机制、MOE 模块和 MLP 模块。在正向传播过程中,输入数据首先通过注意力机制进行特征提取,然后根据层索引的奇偶性选择使用 MOE 模块或 MLP 模块进行进一步处理,最后将处理结果与残差连接相加,得到最终的输出。
2.3.2 模型配置(Config)
class Config(PretrainedConfig):
model_type = "moe_model"
def __init__(self,
hidden_size = 512,
num_attention_heads = 16,
num_key_value_heads = 8,
flash_attn = True,
attention_bias = False,
max_seq_len = 512,
intermediate_size = 2048,
mlp_bias = False,
vocab_size = 6400,
n_layers = 8,
dropout = 0.0,
expert_num = 4,
topk = 2,
output_router_logits = True,
aux_loss_coef = 0.01,
**kwargs):
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.flash_attn = flash_attn
self.attention_bias = attention_bias
self.max_seq_len = max_seq_len
self.intermediate_size = intermediate_size
self.mlp_bias = mlp_bias
self.vocab_size = vocab_size
self.n_layers = n_layers
self.dropout = dropout
self.expert_num = expert_num
self.topk = topk
self.output_router_logits = output_router_logits
self.aux_loss_coef = aux_loss_coef
super().__init__(**kwargs)
模型配置类 Config 继承自 PretrainedConfig,用于存储和管理模型的各种超参数。这些参数决定了模型的结构和训练过程,例如隐藏层大小、注意力头数量、专家网络数量等。通过配置类,可以方便地调整模型的参数,以适应不同的任务和数据集。
2.3.3 大语言模型(LLM)
class LLM(PreTrainedModel):
config_class = Config
def __init__(self, config):
super().__init__(config)
self.config = config
self.vocab_size = self.config.vocab_size
self.n_layers = self.config.n_layers
self.expert_num = self.config.expert_num
self.topk = self.config.topk
self.tokon_embeddings = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
self.dropout = nn.Dropout(self.config.dropout)
self.layers = torch.nn.ModuleList()
for layer_idx in range(self.n_layers):
self.layers.append(DecoderLayer(self.config, layer_idx))
self.norm = RMSNorm(self.config.hidden_size)
self.output = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.tokon_embeddings.weight = self.output.weight
self.apply(self._init_weights)
self.loss = None
self.aux_loss = None
for pn, p in self.named_parameters():
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layers))
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, input_ids, labels, use_kv_cache=False):
all_router_logits = () if self.config.output_router_logits else None
hidden_states = self.tokon_embeddings(input_ids)
hidden_states = self.dropout(hidden_states)
for idx, layer in enumerate(self.layers):
hidden_states, gate_logit = layer(hidden_states, use_kv_cache=use_kv_cache)
if gate_logit is not None:
all_router_logits += (gate_logit, )
hidden_states = self.norm(hidden_states)
if labels is not None:
logits = self.output(hidden_states)
self.loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=0)
else:
logits = self.output(hidden_states[:, [-1], :])
self.loss = None
if self.config.output_router_logits:
self.aux_loss = load_balancing_loss_func(all_router_logits, self.expert_num, self.topk)
if labels is not None:
self.loss += self.config.aux_loss_coef * self.aux_loss.to(self.loss.device)
return CausalLMOutputWithPast(self.loss, logits)
@torch.inference_mode
def generate(self, inputs, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
use_kv_cache=True):
input_ids = inputs['input_ids']
labels = inputs['labels']
s = input_ids.shape[1]
while input_ids.shape[1] < max_new_tokens - 1:
inference_res = self(input_ids, labels, use_kv_cache=use_kv_cache)
logits = inference_res.logits
logits = logits[:, -1, :]
for token in set(input_ids.tolist()[0]):
logits[:, token] /= repetition_penalty
if temperature == 0.0:
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
logits = logits / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1, generator=None)
if idx_next == eos:
break
input_ids = torch.cat((input_ids, idx_next), dim=1)
if stream:
yield input_ids[:, s:]
if not stream:
yield input_ids[:, s:]
LLM 类是整个模型的核心,它继承自 PreTrainedModel。在初始化时,模型会创建词嵌入层、多个解码器层、归一化层和输出层。
正向传播:输入的 input_ids 首先通过词嵌入层转换为向量表示,然后依次经过各个解码器层进行处理,每个解码器层可能使用 MOE 模块或 MLP 模块(索引的奇偶性选择使用 MOE 模块或 MLP 模块进行进一步处理)。处理完成后,经过归一化层,再通过输出层得到预测的 logits。如果提供了 labels,则计算交叉熵损失。同时,如果配置中要求输出路由 logits,还会计算负载均衡辅助损失,并将其添加到主损失中。generate 方法用于模型的推理阶段,根据输入生成新的文本。
反向传播:根据计算得到的损失,使用优化器(在 Trainer 内部管理)进行反向传播,更新模型的参数。优化器会根据学习率和损失函数的梯度来调整模型的权重,使得损失逐渐减小。
2.3.4 主程序
if __name__ == '__main__':
config = Config()
model = LLM(config)
print(f'模型参数量为:{sum(p.numel() for p in model.parameters() if p.requires_grad)}')
data_collator = DefaultDataCollator()
tokenizer = AutoTokenizer.from_pretrained("./tokenizer", use_fast=True)
args = TrainingArguments(output_dir='./moe',
num_train_epochs=10,
do_train=True,
per_device_train_batch_size=2,
gradient_accumulation_steps=1,
# max_steps=15000,
logging_steps=1,
report_to='tensorboard',
save_total_limit=5,
bf16=True,
learning_rate=2e-4,
lr_scheduler_type='cosine',
dataloader_num_workers=8,
dataloader_pin_memory=True,
save_safetensors=False)
dataset = LLMDataset('./train.jsonl', tokenizer=tokenizer, max_seq_len=512)
trainer = Trainer(model=model, args=args, train_dataset=dataset, tokenizer=tokenizer, data_collator=data_collator)
# 如果是初次训练resume_from_checkpoint为false,接着checkpoint继续训练,为True
trainer.train(resume_from_checkpoint=False)
trainer.save_model('./saves/moe')
trainer.save_state()
主程序部分首先创建模型配置对象 config,然后基于该配置初始化 LLM 模型。接着,打印模型的可训练参数量。之后,创建数据整理器、加载分词器、设置训练参数、加载数据集,并使用 Trainer 类进行模型训练。训练完成后,保存模型和训练状态。
2.4 总体代码
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import os
import pandas as pd
from torch.utils.data import IterableDataset, Dataset
import json
import numpy as np
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PretrainedConfig
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForTokenClassification, AutoConfig
from process_data import SFTDataset, LLMDataset
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
hidden_states = hidden_states.float()
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.float()
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q*cos) + (rotate_half(q)*sin)
k_embed = (k*cos) + (rotate_half(k)*sin)
return q_embed, k_embed
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_seq_len=1024):
super(RotaryEmbedding, self).__init__()
self.dim = dim
self.max_seq_len = max_seq_len
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).float().unsqueeze(1)
freqs = t @ inv_freq.unsqueeze(0)
freqs = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", freqs.cos())
self.register_buffer("sin_cached", freqs.sin())
def forward(self, q, k):
cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)
sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)
return apply_rotate_pos_emb(q, k, cos, sin)
def repeat_kv(hidden_states, n_rep):
batch, slen, num_key_value_heads, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim)
return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim)
class Attention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.k_cache, self.v_cache = None, None
self.is_causal = True
self.flash_attn = self.config.flash_attn
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self.residual_dropout = nn.Dropout(self.dropout)
self.attention_dropout = nn.Dropout(self.dropout)
self.rotary_emb = RotaryEmbedding(self.head_dim)
def forward(self, hidden_states, use_kv_cache=False):
b, s = hidden_states.shape[:2]
if use_kv_cache and self.eval():
if self.k_cache is None or self.k_cache.shape[1] != s-1:
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
else:
token = hidden_states[:, -1:, :]
q = torch.cat((torch.zeros_like(hidden_states[:, :-1, :]), self.q_proj(token)), dim=1)
k = torch.cat((self.k_cache, self.k_proj(token)), dim=1)
v = torch.cat((self.v_cache, self.v_proj(token)), dim=1)
self.k_cache, self.v_cache = k, v
else:
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q = q.view(b, s, self.num_heads, self.head_dim)
k = k.view(b, s, self.num_key_value_heads, self.head_dim)
v = v.view(b, s, self.num_key_value_heads, self.head_dim)
q, k = self.rotary_emb(q, k)
k = repeat_kv(k, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
q = q.transpose(1, 2) # b, self.num_heads, s, self.head_dim
k = k.transpose(1, 2) # b, self.num_heads, s, self.head_dim
v = v.transpose(1, 2) # b, self.num_heads, s, self.head_dim
if self.flash_attn:
# q*k转置,(b, self.num_heads, s, self.head_dim)* (b, self.num_heads, self.head_dim,s) = (b, self.num_heads, s, s)
# q*k/sqrt(self.head_dim)*v (b, self.num_heads, s, s)* (b, self.num_heads, s, self.head_dim) = b, self.num_heads, s, self.head_dim
output = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
dropout_p=self.dropout if self.training else 0.0,
is_causal=self.is_causal)
else:
mask = torch.full((1, 1, self.config.max_seq_len, self.config.max_seq_len), float("-inf")) # 初始化掩码
mask = torch.triu(mask, diagonal=1) # 生成上三角掩码
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim) # 计算注意力分数
scores = scores + self.mask[:, :, :s, :s] # 应用掩码
scores = F.softmax(scores.float(), dim=-1).type_as(q) # 计算 softmax
scores = self.attention_dropout(scores) # 应用注意力 dropout
output = torch.matmul(scores, v) # 计算输出
output = output.transpose(1, 2).contiguous().view(b, s, -1) # b, s, self.hidden_size
output = self.o_proj(output)
output = self.residual_dropout(output)
return output
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
def forward(self, x):
down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def load_balancing_loss_func(
gate_logits,
num_experts,
top_k):
concatenated_gate_logits = torch.cat([layer_gate for layer_gate in gate_logits], dim=0) # 各个层的gate_logit进行合并[layers X batch_size X sequence_length, num_experts]
routing_weights = F.softmax(concatenated_gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
router_prob_per_expert = torch.mean(routing_weights, dim=0)
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts
class Gating(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.topk = config.topk
self.expert_num = config.expert_num
self.gate = nn.Linear(self.hidden_size, self.expert_num)
def forward(self, x):
# x dim: b, s, hidden_size
logits = self.gate(x) # gate: b, s, expert_num
logits_topk, indices = logits.topk(self.topk, dim=-1) # 选择概率最大的两个专家,返回两个专家对每个token的概率
zeros = torch.full_like(logits, float("-inf")) # 创建一个全为负无穷的矩阵,用于屏蔽其他专家的概率并重新归一化概率最大的两个专家
sparse_logits = zeros.scatter(dim=-1, index=indices, src=logits_topk) # 将选择的两个专家的概率按指定索引填充
sparse_logits = F.softmax(sparse_logits, dim=-1) # 得到一个稀疏矩阵,选择的两个专家对每个token的概率和为1
gate_logit = logits.view(-1, self.expert_num)
return sparse_logits, indices, gate_logit
class Expert(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
def forward(self, x):
down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class MoE(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.experts = nn.ModuleList([Expert(config) for _ in range(config.expert_num)])
self.gating = Gating(config)
def forward(self, x):
sparse_logits, indices, gate_logit = self.gating(x)#输入经过门控路由,得到选择的专家的索引和专家的概率
final_outputs = torch.zeros_like(x)
x_flat = x.view(-1, x.shape[-1]) # (batch_size * seq_len, dim)
sparse_logits_flat = sparse_logits.view(-1, sparse_logits.shape[-1]) # (batch_size * seq_len, export_num))
#遍历所有的专家,判断当前的专家是否是输入token所选择的专家
for i, expert in enumerate(self.experts):
expert_mask = (indices == i).any(-1) # (batch_size, seq_len)
expert_mask_flat = expert_mask.view(-1) # (batch_size * seq_len)
if expert_mask_flat.any():
expert_input = x_flat[expert_mask_flat] # (seq_true, dim)
export_output = expert(expert_input) # (seq_true, dim)
gate_scores = sparse_logits_flat[expert_mask_flat, i].unsqueeze(1) # (seq_true) --> (seq_true, 1)
weighted_output = export_output * gate_scores # (seq_true, dim)
final_outputs[expert_mask] += weighted_output
return final_outputs, gate_logit
class DecoderLayer(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Attention(config)
self.moe = MoE(config)
self.mlp = MLP(config)
self.input_layernorm = RMSNorm(config.hidden_size)
self.post_attention_layernorm = RMSNorm(config.hidden_size)
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
use_kv_cache
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
use_kv_cache=use_kv_cache
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
if self.layer_idx % 2 == 0:
hidden_states = self.mlp(hidden_states)
gate_logit = None
else:
hidden_states, gate_logit = self.moe(hidden_states)
outputs = residual + hidden_states
return outputs, gate_logit
# 编写自定义配置时需要记住的三个重要事项如下:
# 1、必须继承自 PretrainedConfig
# 2、PretrainedConfig 的 __init__ 方法必须接受任何 kwargs
# 3、这些 kwargs 需要传递给超类的 __init__ 方法。
class Config(PretrainedConfig):
model_type = "moe_model"
def __init__(self,
hidden_size = 512,
num_attention_heads = 16,
num_key_value_heads = 8,
flash_attn = True,
attention_bias = False,
max_seq_len = 512,
intermediate_size = 2048,
mlp_bias = False,
vocab_size = 6400,
n_layers = 8,
dropout = 0.0,
expert_num = 4,
topk = 2,
output_router_logits = True,
aux_loss_coef = 0.01,
**kwargs):
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.flash_attn = flash_attn
self.attention_bias = attention_bias
self.max_seq_len = max_seq_len
self.intermediate_size = intermediate_size
self.mlp_bias = mlp_bias
self.vocab_size = vocab_size
self.n_layers = n_layers
self.dropout = dropout
self.expert_num = expert_num
self.topk = topk
self.output_router_logits = output_router_logits
self.aux_loss_coef = aux_loss_coef
super().__init__(**kwargs)
class LLM(PreTrainedModel):
config_class = Config
def __init__(self, config):
super().__init__(config)
self.config = config
self.vocab_size = self.config.vocab_size
self.n_layers = self.config.n_layers
self.expert_num = self.config.expert_num
self.topk = self.config.topk
self.tokon_embeddings = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
self.dropout = nn.Dropout(self.config.dropout)
self.layers = torch.nn.ModuleList()
for layer_idx in range(self.n_layers):
self.layers.append(DecoderLayer(self.config, layer_idx))
self.norm = RMSNorm(self.config.hidden_size)
self.output = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.tokon_embeddings.weight = self.output.weight
self.apply(self._init_weights)
self.loss = None
self.aux_loss = None
for pn, p in self.named_parameters():
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layers))
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, input_ids, labels, use_kv_cache=False):
all_router_logits = () if self.config.output_router_logits else None
hidden_states = self.tokon_embeddings(input_ids)
hidden_states = self.dropout(hidden_states)
for idx, layer in enumerate(self.layers):
hidden_states, gate_logit = layer(hidden_states, use_kv_cache=use_kv_cache)
if gate_logit is not None:
all_router_logits += (gate_logit, )
hidden_states = self.norm(hidden_states)
if labels is not None:
logits = self.output(hidden_states)
self.loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=0)
else:
logits = self.output(hidden_states[:, [-1], :])
self.loss = None
if self.config.output_router_logits:
self.aux_loss = load_balancing_loss_func(all_router_logits, self.expert_num, self.topk)
if labels is not None:
self.loss += self.config.aux_loss_coef * self.aux_loss.to(self.loss.device)
return CausalLMOutputWithPast(self.loss, logits)
@torch.inference_mode
def generate(self, inputs, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
use_kv_cache=True):
input_ids = inputs['input_ids']
labels = inputs['labels']
s = input_ids.shape[1]
while input_ids.shape[1] < max_new_tokens - 1:
inference_res = self(input_ids, labels, use_kv_cache=use_kv_cache)
logits = inference_res.logits
logits = logits[:, -1, :]
for token in set(input_ids.tolist()[0]):
logits[:, token] /= repetition_penalty
if temperature == 0.0:
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
logits = logits / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1, generator=None)
if idx_next == eos:
break
input_ids = torch.cat((input_ids, idx_next), dim=1)
if stream:
yield input_ids[:, s:]
if not stream:
yield input_ids[:, s:]
if __name__ == '__main__':
config = Config()
model = LLM(config)
print(f'模型参数量为:{sum(p.numel() for p in model.parameters() if p.requires_grad)}')
data_collator = DefaultDataCollator()
tokenizer = AutoTokenizer.from_pretrained("./tokenizer", use_fast=True)
args = TrainingArguments(output_dir='./moe',
num_train_epochs=10,
do_train=True,
per_device_train_batch_size=2,
gradient_accumulation_steps=1,
# max_steps=15000,
logging_steps=1,
report_to='tensorboard',
save_total_limit=5,
bf16=True,
learning_rate=2e-4,
lr_scheduler_type='cosine',
dataloader_num_workers=8,
dataloader_pin_memory=True,
save_safetensors=False)
dataset = LLMDataset('./train.jsonl', tokenizer=tokenizer, max_seq_len=512)
trainer = Trainer(model=model, args=args, train_dataset=dataset, tokenizer=tokenizer, data_collator=data_collator)
# 如果是初次训练resume_from_checkpoint为false,接着checkpoint继续训练,为True
trainer.train(resume_from_checkpoint=False)
trainer.save_model('./saves/moe')
trainer.save_state()
'''
/data/shared/Qwen/Codes/
run_vllm.sh
python -m vllm.entrypoints.openai.api_server \
--trust-remote-code \
--model /mnt/model \
--gpu-memory-utilization 0.9 \
--tensor-parallel-size 1 \
--max-model-len 4096 \
--port 19518 \
--quantization gptq \
--dtype float16 \
--max-num-batched-tokens 4096 \
--max-num-seqs 32
总结
本文详细介绍了混合专家模型(MOE)的原理、优势以及如何从头开始训练一个 MOE 模型。MOE 模型通过引入多个专家网络和门控机制,在提升模型表达能力的同时,有效降低了计算成本,提高了模型的可扩展性。代码部分涵盖了模型的各个组件,包括归一化层、旋转位置嵌入、注意力机制、多层感知机、门控网络、专家网络等,以及模型的整体结构和训练过程。通过调整模型配置参数,可以根据不同的任务和数据集进行定制化训练。在实际应用中,MOE 模型可以广泛应用于自然语言处理、计算机视觉等领域,为解决复杂的任务提供了一种有效的方法。同时,负载均衡损失函数的引入有助于确保每个专家网络都能得到充分利用,避免资源浪费。未来,随着技术的不断发展,MOE 模型有望在大模型领域发挥更加重要的作用。
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)