torch.nn.Embedding详解:为什么要用Embedding,Embedding怎么用,Embedding的底层源码
本文深入解析了深度学习中的Embedding技术,主要涵盖三方面内容:首先阐述了Embedding的价值,它能将离散词ID转换为可训练的连续向量,解决传统编码单一性和不可训练的问题;其次介绍了PyTorch中nn.Embedding的使用方法,包括参数设置和输入输出格式转换;最后从源码层面揭示了Embedding的底层实现机制,指出其本质是一个可训练的查找表(vocab_size×embeddin
目录
在多模态大模型盛行的当下,我们常常会看到词嵌入Embedding的身影,这篇文章就让我们来揭开Embedding神秘的面纱。
一、为什么要用Embedding
就以我们自然语言处理中的文本数据为例,来总地介绍下整体的文本预处理的流程和Embedding在其中起到的作用。
在我们读入数据集的时候,我们读入数据的格式往往是[batch_size, seq_len]的形式将数据读入的,batch_size指的是一次读入了多少句子(样本),seq_len指的是一个句子中分词的个数,这些分词都是由tokenizer对原始句子进行分词完后将词表示为一个序号,这个序号就是这个分词在本数据集的词典中所分配好的编号,具体格式参考下图:

在代码中的表示应为如下过程:
sentences = [["我","爱","你"],["你","好","啊"]]
|
|
|
|
\/
sentences = [[221,356,778],[778,498,519]]
接下来就是重头戏了,Eembbding登场,在每个句子中的分词被编码之后,传入模型之中会面临两个问题:1.用编号表示单个分词显得过于单一,模型很难去学习各个分词之间的更多细节;2.我们的分词表示是以常数形式进行表示的,无法参与进我们后续的模型训练当中,我们希望模型的输入表示也是一个可以被训练的对象,能够更好地契合我们的任务。
Eembedding就很好解决了这两个问题,首先,Embedding会将每个分词表示为embedding_dim长度的变量,即输入格式由[batch_size, seq_len]变为[batch_size, seq_len,embedding_dim]的变量,在单个分词上看就是(,1)——> (,embedding_dim),并且该变量还是一个可训练的parameter。
再补充一点,Embedding 层将输入的离散词或 token 映射为可学习的向量表示,与传统的 word2vec 通过独立训练生成静态词向量矩阵 的方式不同。Embedding 的参数会在整个神经网络的训练过程中与其他层一起更新,因此能够根据具体任务不断调整词向量,使输入表示更加贴合模型的目标和语义需求。
二、Embedding怎么用
在PyTorch中,`nn.Embedding`层是用于处理离散数据(如单词或类别)的关键组件,特别常见于自然语言处理(NLP)和推荐系统等任务。它的主要功能是将输入的整数索引映射到连续的高维向量空间中,即将索引转化为嵌入向量。
torch.nn.Embedding(vocab_size, embedding_dim)
- `vocab_size`: 嵌入表的大小,即词汇表的大小或类别数。它定义了有多少个不同的“离散输入”可以映射到嵌入向量。
- `embedding_dim`: 每个离散输入(类别、单词等)将被映射到的连续向量的维度大小。
`nn.Embedding`的输入通常是整数(类别索引或词汇索引),它会根据输入的索引从一个大小为 `(vocab_size, embedding_dim)` 的查找表中检索出相应的嵌入向量。
import torch
import torch.nn as nn
# 定义Embedding层
embedding = nn.Embedding(10, 3) # vocab_size=10, embedding_dim=3
# 输入索引
input_indices = torch.tensor([1, 2, 3])
# 获取嵌入向量
output = embedding(input_indices)
print(output)
运行结果:
tensor([[ 1.5522, 0.7179, 1.6805],
[ 2.1118, 0.2995, 0.4167],
[-0.6033, -0.4972, -1.6700]], grad_fn=<EmbeddingBackward0>)
- embedding层携带巨大的权重矩阵,是参数量计算的关键过程之一
print(embedding.weight) #结构为10,3
运行结果:
Parameter containing:
tensor([[ 0.4752, -0.2457, 0.2101],
[ 1.5522, 0.7179, 1.6805],
[ 2.1118, 0.2995, 0.4167],
[-0.6033, -0.4972, -1.6700],
[-0.8719, -0.7207, 0.8305],
[ 1.2962, -1.2880, 0.8838],
[-0.7804, -0.1872, 0.3502],
[-0.2817, -0.9322, 0.5499],
[-0.5277, 0.8808, -1.6055],
[ 0.5706, 0.9455, -0.0734]], requires_grad=True)
三、Embedding的底层源码
很多人用 nn.Embedding 都只是机械地记住“要传 vocab_size 和 embedding_dim”,但没真正理解背后的机制。我们来把它从源码层面和数学逻辑层面都讲清楚。
1)nn.Embedding 的核心思想
nn.Embedding 是一个查表层(lookup table layer),它的功能是把离散的词ID映射成一个连续的稠密向量。
一句话解释:
它本质上就是一个矩阵
weight,形状为[vocab_size, embedding_dim]。
当你输入一个 token id,比如42,它就取出第42行的向量。
2)基本用法
import torch
import torch.nn as nn
emb = nn.Embedding(num_embeddings=10000, embedding_dim=300)
input = torch.tensor([1, 3, 5])
output = emb(input)
print(output.shape) # torch.Size([3, 300])
这里:
-
num_embeddings= 10000 → 表示词表大小(vocab_size) -
embedding_dim= 300 → 每个词的向量维度
3)源码核心实现(简化版)
官方源码在 torch/nn/modules/sparse.py 中,大致如下(核心逻辑简化后):
class Embedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super(Embedding, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
# 初始化权重矩阵 (词向量表)
self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
self.reset_parameters()
def forward(self, input):
return F.embedding(input, self.weight)
def embedding(input, weight):
# input: LongTensor containing indices (e.g. [1, 3, 5])
# weight: [vocab_size, embedding_dim]
return weight[input]
核心一句话是:
output = weight[input]
也就是说,embedding层就是在一个矩阵里按索引取行!
4)为什么输入维度要是 vocab_size
这是因为 embedding 层的权重矩阵的“行数”必须覆盖所有可能的 token id。
比如:
-
你的词表大小是 10000;
-
token id 的范围是
[0, 9999]; -
那 embedding 矩阵就必须有 10000 行。
否则——如果 id=9999 而矩阵只有 5000 行,你查表时就会越界,报错:
IndexError: index out of range in self
所以:
“输入维度 = vocab_size” 是为了让每个词(或ID)都有唯一对应的一行向量。
5)梯度与训练机制
虽然看似只是查表,但这层也是可训练的:
-
反向传播时,只会对被访问到的那几行计算梯度;
-
所以 embedding 层更新是稀疏的(sparse update);
-
这就是它高效的原因。
例如输入 [1, 3, 5],只有第 1、3、5 行的向量会被更新。
6)小结
| 参数 | 含义 | 对应矩阵维度 | 作用 |
|---|---|---|---|
num_embeddings |
词表大小 | 行数 | 每个词一个向量 |
embedding_dim |
向量维度 | 列数 | 向量的语义维度 |
| 输入(input) | 词ID序列 | [batch, seq_len] |
用来查表的索引 |
| 输出(output) | 词向量序列 | [batch, seq_len, embedding_dim] |
词向量结果 |
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)