大模型之-层归一化RMSNorm的实现
【代码】大模型之-层归一化RMSNorm的实现。
·
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""层归一化构造函数
Args:
dim (int): 层归一化的输入
eps (float, optional): 偏置值. Defaults to 1e-6.
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
"""归一化
Args:
x (_type_): 输入
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim = True) + self.eps)
def forward(self, x):
"""归一化层的前向传播
Args:
x (_type_): 输入
Details:
self.weight * self._norm:广播broadcast操作,输出的shape跟_norm后的shape一样,self.weight是可学习的
type_as(x):对归一化后的高精度 x 转为原来的类型 float16
"""
return self.weight * self._norm(x.float()).type_as(x)
解释:

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)