import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """层归一化构造函数

        Args:
            dim (int): 层归一化的输入
            eps (float, optional): 偏置值. Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        
    def _norm(self, x):
        """归一化

        Args:
            x (_type_): 输入
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim = True) + self.eps)
    
    def forward(self, x):
        """归一化层的前向传播

        Args:
            x (_type_): 输入
        Details:
            self.weight * self._norm:广播broadcast操作,输出的shape跟_norm后的shape一样,self.weight是可学习的
            type_as(x):对归一化后的高精度 x 转为原来的类型 float16
        """
        return self.weight * self._norm(x.float()).type_as(x)

解释:
在这里插入图片描述
在这里插入图片描述

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐