Transformer中的前馈网络(Feed Forward)

  • Feed Forward存在于Encoder与Decoder模块中
  • Feed Forward只是简单的全连接神经网络(两层线性变换),中间层维度 $d_{ff}$ 为2048,使用ReLU激活函数:
    $$\mathrm{FFN}(x) = \max(0,\ xW_1 + b_1)W_2 + b_2$$

代码实现

以下是前面已经实现的Embedding、Positional Encoding、MultiHeadAttention与Add&Norm代码:

import torch.nn as nn
import torch.nn.functional as F
import torch
import math
# Toy hyper-parameters shared by the demo below.
vocab_size, d_model = 10, 4
# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
class MyEmbedding(nn.Embedding):
    """Token embedding whose output is scaled by sqrt(embedding_dim).

    Args:
        num_embeddings: vocabulary size (number of rows in the table).
        embedding_dim: size of each embedding vector (d_model).
        device: device on which the embedding table is allocated.
    """

    def __init__(self, num_embeddings, embedding_dim, device):
        # nn.Embedding stores num_embeddings / embedding_dim itself.
        super().__init__(num_embeddings, embedding_dim, device=device)
        self.device = device

    def forward(self, input_ids):
        """Look up ``input_ids`` and multiply the vectors by sqrt(embedding_dim).

        The scale is a plain Python float, so it applies on whatever device the
        embedding weight lives on (no global device, no per-call tensor copy).
        """
        return super().forward(input_ids) * math.sqrt(self.embedding_dim)


class MyPositonalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence-first inputs.

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    Args:
        seq_length: maximum sequence length the table is precomputed for.
        d_model: embedding size; must be even (sin/cos fill alternating columns).
        device: device on which the precomputed table is stored.

    Raises:
        ValueError: if ``d_model`` is odd.
    """

    def __init__(self, seq_length, d_model, device):
        if d_model % 2:
            raise ValueError("embedding_dim must be an even number for positional encoding.")
        super().__init__()
        self.seq_length = seq_length
        self.d_model = d_model
        self.device = device
        pe = torch.zeros(self.seq_length, self.d_model)
        pos = torch.arange(0, self.seq_length, dtype=torch.float).unsqueeze(1)
        # Equivalent to 1 / 10000 ** (2i / d_model), but computed in log space
        # (the approach used in the PyTorch examples) for better precision.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        # (seq_length, d_model) -> (seq_length, 1, d_model) so it broadcasts over the batch dim.
        pe = pe.unsqueeze(1)
        if device is not None:
            pe = pe.to(device)
        # Unlike a plain attribute, a buffer is saved in state_dict and follows
        # module.to(...); it is not a Parameter, so optimizers never update it.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add positional encodings to ``x`` of shape (seq_length, batch_size, d_model).

        Raises:
            ValueError: if ``x`` is longer than the precomputed table.
        """
        seq_length = x.shape[0]
        if seq_length > self.seq_length:
            raise ValueError(
                f"sequence length {seq_length} exceeds the maximum {self.seq_length} "
                "this positional encoding was built for."
            )
        # Follow the input's device instead of a module-global one; this is a
        # no-op when the buffer already lives there.
        return x + self.pe[:seq_length].to(x.device)

class MyAttention(nn.Module):
    """Scaled dot-product attention over inputs that are already split into heads.

    ``q``, ``k`` and ``v`` are 4-D tensors laid out as
    (seq_len, batch_size, num_head, head_dim); the result uses the same layout,
    with the last dimension taken from ``v``. When ``mask`` is given, score
    positions where ``mask == 0`` are set to -inf before the softmax (a row that
    is entirely masked therefore yields NaN).
    """

    def __init__(self):
        super().__init__()

    def forward(self, q, k, v, mask=None):
        _, _, _, d_k = q.shape
        # (seq, batch, head, dim) -> (batch, head, seq, dim)
        heads_first = (1, 2, 0, 3)
        queries = q.permute(*heads_first)
        keys_t = k.permute(*heads_first).transpose(-2, -1)
        values = v.permute(*heads_first)

        # logits: (batch, head, seq_q, seq_k)
        logits = torch.matmul(queries, keys_t) / math.sqrt(d_k)
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(logits, dim=-1)

        # (batch, head, seq, d_v) -> (seq, batch, head, d_v)
        context = torch.matmul(weights, values)
        return context.permute(2, 0, 1, 3)
        
        
class MyMultiHeadAttention(nn.Module):
    """Multi-head self-attention for sequence-first inputs.

    Q, K and V are each produced by one d_model -> d_model projection and then
    split into ``num_head`` heads of size ``d_model // num_head``. The heads are
    attended independently by ``MyAttention``, re-concatenated, and passed
    through an output projection.

    Args:
        num_head: number of attention heads.
        d_model: model width; must be divisible by ``num_head``.
        device: device for the projection weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_head``.
    """

    def __init__(self, num_head, d_model, device):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if d_model % num_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_head ({num_head})."
            )
        self.num_head = num_head
        self.d_model = d_model
        self.device = device
        self.d_k = d_model // num_head
        self.d_v = self.d_k
        # Project once with a full-width linear layer, then split across heads.
        self.q_proj = nn.Linear(d_model, d_model, bias=False, device=device)
        self.k_proj = nn.Linear(d_model, d_model, bias=False, device=device)
        self.v_proj = nn.Linear(d_model, d_model, bias=False, device=device)
        self.attention = MyAttention()
        self.o_proj = nn.Linear(d_model, d_model, bias=False, device=device)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x`` of shape (seq_len, batch_size, d_model).

        ``mask`` is forwarded to ``MyAttention``, which excludes score positions
        where ``mask == 0``. Returns a tensor of shape (seq_len, batch_size, d_model).
        """
        seq_len, batch_size, _ = x.shape
        # (seq_len, batch_size, d_model) -> (seq_len, batch_size, num_head, d_k)
        q = self.q_proj(x).reshape(seq_len, batch_size, self.num_head, self.d_k)
        k = self.k_proj(x).reshape(seq_len, batch_size, self.num_head, self.d_k)
        v = self.v_proj(x).reshape(seq_len, batch_size, self.num_head, self.d_v)
        atten_out = self.attention(q, k, v, mask)
        # Concatenate heads: (seq_len, batch_size, num_head, d_v) -> (seq_len, batch_size, d_model).
        # reshape copies when the permuted attention output is non-contiguous.
        atten_out = atten_out.reshape(seq_len, batch_size, -1)
        return self.o_proj(atten_out)


class MyAddAndNorm(nn.Module):
    """Residual add followed by layer normalization: ``LayerNorm(x + change_x)``.

    ``change_x`` is the output of the preceding sub-layer (e.g. attention or the
    feed-forward network); normalization runs over the last dimension, of size
    ``d_model``.

    Args:
        d_model: size of the normalized (last) dimension.
        device: device for the LayerNorm affine parameters.
    """

    def __init__(self, d_model, device):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model, device=device)
        self.device = device
        self.d_model = d_model

    def forward(self, x, change_x):
        residual = x + change_x
        normalized = self.layer_norm(residual)
        return normalized

实现Feed Forward类

class MyFeedForward(nn.Module):
    """Position-wise feed-forward network: FFN(x) = max(0, x W1 + b1) W2 + b2.

    Expands the last dimension from ``d_model`` to ``d_ff``, applies ReLU, then
    projects back to ``d_model``.

    Args:
        d_model: input/output feature size.
        d_ff: hidden (inner) layer size.
        device: device for the two linear layers.
    """

    def __init__(self, d_model, d_ff, device):
        super().__init__()
        self.d_model, self.d_ff, self.device = d_model, d_ff, device
        # Creation order matters for reproducible random init: up, then down.
        self.up_proj = nn.Linear(d_model, d_ff, device=device)
        self.down_proj = nn.Linear(d_ff, d_model, device=device)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.up_proj(x)
        activated = self.relu(hidden)
        return self.down_proj(activated)

测试代码

# Demo: embedding -> positional encoding -> masked self-attention -> Add&Norm -> FFN.
if_test = True
d_ff = 2 * d_model
embedding_layer = MyEmbedding(vocab_size, d_model, device)
# vocab_size (10) is reused as the maximum sequence length of the positional
# table; it only has to be >= the demo's sequence length.
positional_ecoding_layer = MyPositonalEncoding(vocab_size, d_model, device)
multi_head_attention = MyMultiHeadAttention(num_head=2, d_model=d_model, device=device)
add_and_norm = MyAddAndNorm(d_model=d_model, device=device)
feed_forward = MyFeedForward(d_model, d_ff, device)
if if_test:
    # (batch_size, seq_len) token ids
    token_ids = torch.tensor([[0, 1, 2], [2, 3, 4]], dtype=torch.long).to(device)
    seq_len = token_ids.shape[1]
    embedding_ids = embedding_layer(token_ids)
    # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model), the layout the layers expect
    embedding_ids = embedding_ids.transpose(0, 1)
    pos_ids = positional_ecoding_layer(embedding_ids)
    print(pos_ids.shape)
    # Causal (lower-triangular) mask of shape (1, seq_len, seq_len); it broadcasts
    # over batch and heads inside the attention scores.
    mask = torch.tril(torch.ones((seq_len, seq_len)), diagonal=0).unsqueeze(0).to(device)
    attention_ids = multi_head_attention(pos_ids, mask)
    add_and_norm_addention_ids = add_and_norm(pos_ids, attention_ids)
    print(add_and_norm_addention_ids)
    print(add_and_norm_addention_ids.shape)
    output_ids = feed_forward(add_and_norm_addention_ids)
    print(output_ids)
    print(output_ids.shape)
Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐