解析transformer——5:Feed Forward,前馈网络
摘要 本文介绍了Transformer中的前馈网络(Feed Forward)实现。该网络同时存在于编码器和解码器模块中,由两层全连接层组成,中间使用ReLU激活函数,隐藏层维度 $d_{ff}$ 为2048。文章提供了完整的PyTorch实现代码,包括Embedding层、位置编码(Positional Encoding)、多头注意力机制(MultiHeadAttention)、Add&Norm层以及前馈网络(Feed Forward)。
·
transformer中的前馈网络(Feed Forward)
- Feed Forward同时存在于Encoder与Decoder模块中
- Feed Forward只是简单的全连接神经网络:两个线性层之间使用ReLU激活函数,隐藏层维度 $d_{ff}$ 为2048
$$\mathrm{FFN}(x) = \max(0,\ xW_1 + b_1)W_2 + b_2$$
代码实现
以下是前面文章中已经实现的Embedding、Positional Encoding、MultiHeadAttention与Add&Norm代码:
import torch.nn as nn
import torch.nn.functional as F
import torch
import math
# Toy sizes for the demo: a 10-token vocabulary and a 4-dimensional model.
vocab_size, d_model = 10, 4
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
class MyEmbedding(nn.Embedding):
    """Token embedding whose output is scaled by sqrt(embedding_dim).

    A thin wrapper around ``nn.Embedding``: the looked-up vectors are multiplied
    by ``sqrt(embedding_dim)`` before being returned.

    Args:
        num_embeddings: vocabulary size (number of rows in the lookup table).
        embedding_dim: width of each embedding vector.
        device: device on which the weight table is allocated.
    """

    def __init__(self, num_embeddings, embedding_dim, device):
        # nn.Embedding.__init__ already stores num_embeddings / embedding_dim.
        super().__init__(num_embeddings, embedding_dim, device=device)
        self.device = device

    def forward(self, input_ids):
        """Look up ``input_ids`` and scale the vectors by sqrt(embedding_dim).

        The scale is a plain Python float, so the result stays on whatever device
        the weights live on. Nothing depends on a module-level ``device`` global,
        and no tensor is allocated per call.
        """
        return super().forward(input_ids) * math.sqrt(self.embedding_dim)
class MyPositonalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first input.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        seq_length: maximum sequence length covered by the encoding table.
        d_model: model width; must be even so sin/cos columns pair up.
        device: device on which the encoding table is allocated (``None`` keeps
            PyTorch's default device).

    Raises:
        ValueError: if ``d_model`` is odd.
    """

    def __init__(self, seq_length, d_model, device):
        if d_model % 2:
            raise ValueError("embedding_dim must be an even number for positional encoding.")
        super().__init__()
        self.seq_length = seq_length
        self.d_model = d_model
        self.device = device
        pe = torch.zeros(self.seq_length, self.d_model)
        pos = torch.arange(0, self.seq_length, dtype=torch.float).unsqueeze(1)
        # Mathematically 1 / 10000 ** (arange(0, d_model, 2) / d_model), but computed
        # in log space (as PyTorch's reference implementation does) for precision.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        # (seq_length, d_model) -> (seq_length, 1, d_model) so it broadcasts over batch.
        pe = pe.unsqueeze(1)
        if device is not None:
            pe = pe.to(device)
        # Unlike a plain attribute (self.pe = pe), a buffer is saved in state_dict
        # and follows module.to(...), but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.shape[0]`` positions.

        Args:
            x: tensor of shape (seq_length, batch_size, d_model).

        Raises:
            ValueError: if ``x`` is longer than the table built in ``__init__``.
        """
        seq_length = x.shape[0]
        max_length = self.pe.shape[0]
        if seq_length > max_length:
            # Without this check a too-long input fails with an opaque broadcast
            # error, or silently broadcasts one position when max_length == 1.
            raise ValueError(
                f"sequence length {seq_length} exceeds the maximum {max_length} "
                "supported by this positional encoding."
            )
        # Follow the input's device rather than a module-level global.
        return x + self.pe[:seq_length].to(x.device)
class MyAttention(nn.Module):
    """Scaled dot-product attention over sequence-first, per-head tensors.

    ``q``, ``k`` and ``v`` are laid out as (seq_len, batch_size, num_head, d);
    the result uses the same layout, with the last axis equal to v's width.
    """

    def __init__(self):
        super().__init__()

    def forward(self, q, k, v, mask=None):
        """Compute softmax(q @ k^T / sqrt(d_k)) @ v for every batch item and head.

        Args:
            q, k, v: tensors of shape (seq_len, batch_size, num_head, d_k / d_v).
            mask: optional tensor broadcastable to
                (batch_size, num_head, q_len, k_len); scores where ``mask == 0``
                are set to -inf before the softmax.

        Returns:
            Tensor of shape (seq_len, batch_size, num_head, d_v).
        """
        _, _, _, head_dim = q.shape
        # (seq, batch, head, d) -> (batch, head, seq, d)
        query = q.permute(1, 2, 0, 3)
        value = v.permute(1, 2, 0, 3)
        # Keys are additionally transposed: (batch, head, d, seq).
        key_t = k.permute(1, 2, 0, 3).transpose(-2, -1)

        # (batch, head, q_len, k_len)
        logits = torch.matmul(query, key_t) / math.sqrt(head_dim)
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(logits, -1)

        # (batch, head, q_len, d_v) -> (q_len, batch, head, d_v)
        return torch.matmul(weights, value).permute(2, 0, 1, 3)
class MyMultiHeadAttention(nn.Module):
    """Multi-head self-attention over sequence-first input.

    ``x`` is projected once to full-width Q, K and V. Each projection is split
    into ``num_head`` heads of width ``d_model // num_head``, and attention is
    computed per head by ``MyAttention``. The concatenated heads then pass
    through an output projection. All projections are bias-free.

    Args:
        num_head: number of attention heads.
        d_model: model width; must be divisible by ``num_head``.
        device: device on which the projection weights are allocated.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_head``.
    """

    def __init__(self, num_head, d_model, device):
        super().__init__()
        self.num_head = num_head
        self.d_model = d_model
        self.device = device
        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        if d_model % num_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_head ({num_head})."
            )
        self.d_k = d_model // num_head
        self.d_v = self.d_k
        # Project once at full width, then split the result across heads.
        self.q_proj = nn.Linear(d_model, d_model, bias=False, device=device)
        self.k_proj = nn.Linear(d_model, d_model, bias=False, device=device)
        self.v_proj = nn.Linear(d_model, d_model, bias=False, device=device)
        self.attention = MyAttention()
        self.o_proj = nn.Linear(d_model, d_model, bias=False, device=device)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape (seq_len, batch_size, d_model).
            mask: optional mask passed unchanged to ``MyAttention``.

        Returns:
            Tensor of shape (seq_len, batch_size, d_model).
        """
        seq_len, batch_size, _ = x.shape
        # (seq_len, batch_size, d_model) -> (seq_len, batch_size, num_head, d_k)
        q = self.q_proj(x).reshape(seq_len, batch_size, self.num_head, self.d_k)
        k = self.k_proj(x).reshape(seq_len, batch_size, self.num_head, self.d_k)
        v = self.v_proj(x).reshape(seq_len, batch_size, self.num_head, self.d_v)
        atten_out = self.attention(q, k, v, mask)
        # MyAttention returns a permuted view; reshape copies when needed while
        # merging heads: (seq_len, batch_size, num_head * d_v).
        atten_out = atten_out.reshape(seq_len, batch_size, -1)
        return self.o_proj(atten_out)
class MyAddAndNorm(nn.Module):
    """Residual connection followed by layer normalisation: LayerNorm(x + change_x).

    Args:
        d_model: size of the last dimension that LayerNorm normalises over.
        device: device on which the LayerNorm affine parameters are allocated.
    """

    def __init__(self, d_model, device):
        super().__init__()
        self.d_model, self.device = d_model, device
        self.layer_norm = nn.LayerNorm(d_model, device=device)

    def forward(self, x, change_x):
        """Add ``change_x`` to ``x`` element-wise, then layer-normalise the sum."""
        residual = x + change_x
        return self.layer_norm(residual)
实现Feed Forward类
class MyFeedForward(nn.Module):
    """Position-wise feed-forward network: FFN(x) = max(0, x W1 + b1) W2 + b2.

    Args:
        d_model: input and output width.
        d_ff: hidden (inner) width.
        device: device on which the linear layers' parameters are allocated.
    """

    def __init__(self, d_model, d_ff, device):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.device = device
        # Expand d_model -> d_ff, apply ReLU, then project back d_ff -> d_model.
        self.up_proj = nn.Linear(d_model, d_ff, device=device)
        self.down_proj = nn.Linear(d_ff, d_model, device=device)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the two-layer ReLU MLP along the last dimension of ``x``."""
        hidden = self.up_proj(x)
        hidden = self.relu(hidden)
        return self.down_proj(hidden)
测试代码
# Smoke test: push a tiny batch through embedding -> positional encoding ->
# multi-head attention -> Add&Norm -> feed-forward, printing shapes and values.
if_test = True
d_ff = 2 * d_model  # toy hidden width (the article's reference value is 2048)
# Maximum sequence length covered by the positional-encoding table. It bounds
# input sequence lengths and is unrelated to the vocabulary size, so it is
# kept separate from vocab_size.
max_seq_len = 10
embedding_layer = MyEmbedding(vocab_size, d_model, device)
positional_ecoding_layer = MyPositonalEncoding(max_seq_len, d_model, device)
multi_head_attention = MyMultiHeadAttention(num_head=2, d_model=d_model, device=device)
add_and_norm = MyAddAndNorm(d_model=d_model, device=device)
feed_forward = MyFeedForward(d_model, d_ff, device)
if if_test:
    # Two sequences of three token ids each: (batch_size, seq_len).
    token_ids = torch.tensor([[0, 1, 2], [2, 3, 4]], dtype=torch.long).to(device)
    embedding_ids = embedding_layer(token_ids)
    # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model), the
    # sequence-first layout the positional encoding and attention expect.
    embedding_ids = embedding_ids.transpose(0, 1)
    pos_ids = positional_ecoding_layer(embedding_ids)
    print(pos_ids.shape)
    # Causal (lower-triangular) mask sized from the actual sequence length
    # instead of a hard-coded 3; shape (1, seq_len, seq_len).
    seq_len = token_ids.shape[1]
    mask = torch.tril(torch.ones((seq_len, seq_len)), diagonal=0).unsqueeze(0).to(device)
    # Previously the mask was built but never passed, so attention ran unmasked.
    # Drop mask=mask to attend over all positions (encoder-style).
    attention_ids = multi_head_attention(pos_ids, mask=mask)
    add_and_norm_addention_ids = add_and_norm(pos_ids, attention_ids)
    print(add_and_norm_addention_ids)
    print(add_and_norm_addention_ids.shape)
    output_ids = feed_forward(add_and_norm_addention_ids)
    print(output_ids)
    print(output_ids.shape)
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)