1 概念速览

术语 定义 关键点
编码器 (Encoder) 将输入序列 $\mathbf x={x_1\dots x_n}$ 压缩为 上下文表示 $\mathbf c$(向量或张量) 提炼关键信息,支持变长输入
解码器 (Decoder) 在 $\mathbf c$ 的条件下自回归地产生输出序列 $\mathbf y={y_1\dots y_m}$ 生成、翻译、预测等;可视为条件语言模型
训练目标 最大化对数似然 $\log p_\theta(\mathbf y\mid\mathbf x)$ 典型损失:交叉熵
为什么有效 原生支持「变长输入 → 变长输出」,并能通过注意力显式对齐 机器翻译、摘要等 Seq2Seq 任务的基础

2 网络形态与“同构”迷思

Encoder Decoder 场景示例 备注
RNN/LSTM/GRU RNN/LSTM/GRU 早期 NMT、时间序列预测 纵向依赖强,训练难度大
卷积 CNN 反卷积或 CNN U-Net 图像分割 本地感受野,建模全局需扩张卷积
Transformer Transformer 主流文本/多模态生成 并行化、长依赖;显存吃紧
Hybrid Hybrid 长序列、流式语音 ASR 编码器和解码器可异构

结论: 编码器与解码器完全可以使用不同类型网络。

  • RNN → Transformer:先压缩时序,再高效全局注意力解码
  • CNN → CTC 解码器:流式语音,低延迟
  • ViT → 文本 Transformer:图像字幕(BLIP-2)

3 注意力 & 上下文

  1. 固定上下文向量瓶颈
    早期 Seq2Seq 仅传递单向量 $\mathbf c$,长句信息易丢失。
  2. Bahdanau / Luong 注意力
    解码时对编码隐藏态打分,动态读取相关信息。
  3. Transformer
    编码器和解码器均以多头自注意力为核心,完全抛弃循环结构。
  4. 跨模态注意力
    视觉 token ↔ 字幕 token,或语音特征 ↔ 文本 token。

4 典型任务与落地框架

任务 输入 ➜ 输出 主流开源模型 / 库
机器翻译 句子 ➜ 句子 Transformer、mBART、MarianMT
文本摘要 长文 ➜ 简短摘要 BART、Pegasus、T5
对话生成 历史对话 ➜ 回复 DialoGPT、LLaMA-Chat
语音识别 声谱图 ➜ 文本 Whisper、RNN-T
图像字幕 图像特征 ➜ 文字 BLIP-2、PaLI
时间序列预测 历史序列 ➜ 未来序列 Informer、Seq2Seq RNN
代码补全 代码上下文 ➜ 续写 CodeT5、StarCoder

5 极简 PyTorch 模板

import torch, torch.nn as nn
from random import random

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, sos_id, eos_id, max_len=128):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.sos, self.eos, self.max_len = sos_id, eos_id, max_len

    def forward(self, src_ids, tgt_ids=None, teacher_forcing=0.5):
        # ① 编码
        memory = self.encoder(src_ids)

        # ② 解码(训练或推理)
        B = src_ids.size(0)
        ys = torch.full((B, 1), self.sos, dtype=torch.long, device=src_ids.device)
        outputs = []

        for t in range(self.max_len):
            logits = self.decoder(ys, memory)           # (B, t+1, V)
            next_token = logits[:, -1].argmax(-1, keepdim=True)
            outputs.append(next_token)

            if tgt_ids is not None and random() < teacher_forcing:
                ys = torch.cat([ys, tgt_ids[:, t:t+1]], dim=1)
            else:
                ys = torch.cat([ys, next_token], dim=1)

            if (next_token == self.eos).all():
                break

        return torch.cat(outputs, dim=1)      # 形如 (B, L)

用法示例

enc = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=6)
dec = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model=512, nhead=8), num_layers=6)
model = Seq2Seq(enc, dec, sos_id=0, eos_id=2)

6 进阶主题

方向 思路 代表工作
长上下文 稀疏/线性注意力(Performer, Longformer) LongT5, Flash-Attention
检索增强 (RAG) 外部向量数据库返回候选段落,拼接进解码器输入 RETRO, Atlas, LlamaIndex
多模态对齐 视觉/音频编码器 + 文本解码器;对比学习统一 token 空间 BLIP-2, Gemini, GPT-4o
效率优化 混合精度、蒸馏、小模型教师、KV 缓存、模型并行 DeepSpeed ZeRO-3, Flash-Decoding

7 小结与实践建议

  1. 架构是方法论:编码器负责理解,解码器负责表达,二者可自由组合。
  2. 先跑通,再混搭:先用官方 Transformer 教程跑 NMT baseline,再尝试 LSTM-Enc + Transformer-Dec 等混搭,体会差异。
  3. 关注长上下文与检索增强:RAG 正成为工业搜索-生成系统的主流范式。
  4. 做项目,反推理论:挑一项真实业务(如 PDF 摘要、邮 件分类),落地一条端到端流水线,遇到痛点再查论文,理解会更深。
Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐