Day 16: 生成模型基础 (Generative Models Basics)

摘要:在前面的课程中,我们学习的分类、检测、分割任务都属于判别式模型(判断“是什么”)。而今天,我们将进入生成式模型(创造“有什么”)的奇妙世界。从“左右互搏”的 GAN 到“概率重构”的 VAE,本文将带你推开 AI 内容生成(AIGC)的大门。


1. 判别式 vs 生成式

在深入具体模型之前,先搞清楚两种建模思路的区别:

模型类型 英文 核心逻辑 例子 比喻
判别式模型 Discriminative 学习条件概率 $P(Y X)$ 分类、回归
生成式模型 Generative 学习联合概率 P(X,Y)P(X,Y)P(X,Y) 或分布 P(X)P(X)P(X) 图像生成、文本生成 画家(从白纸画出一幅画)

简单来说,判别式模型在乎决策边界,生成式模型在乎数据分布


2. GAN:生成对抗网络 (Generative Adversarial Networks)

2014年,Ian Goodfellow 提出了 GAN,被 Yann LeCun 誉为“过去十年机器学习领域最有趣的想法”。

2.1 核心思想:零和博弈

GAN 由两个网络组成,它们就像造假币的团伙警察

  • 生成器 (Generator, G):负责“造假”。输入随机噪声 zzz,输出假图片 G(z)G(z)G(z)。目标是骗过判别器。
  • 判别器 (Discriminator, D):负责“打假”。输入图片 xxx,判断它是真的(来自数据集)还是假的(来自 GGG)。

2.2 训练过程

两者交替训练,最终达到纳什均衡:

  1. 固定 G,训练 D:让 D 能尽可能区分真图和假图。
  2. 固定 D,训练 G:让 G 生成的假图能被 D 误判为真。

min⁡Gmax⁡DV(D,G)=Ex∼pdata(x)[log⁡D(x)]+Ez∼pz(z)[log⁡(1−D(G(z)))] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log (1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

2.3 训练的痛点

  • 模式崩塌 (Mode Collapse):G 发现生成某种特定的图很容易骗过 D,于是它就一直生成这一种图,失去了多样性。
  • 训练不稳定:两者很难平衡,如果 D 太强,G 梯度消失学不动;如果 G 太强,D 随便猜,也没法指导 G。

3. 经典 GAN 变体

为了解决原始 GAN 的问题,无数变体涌现:

3.1 DCGAN (Deep Convolutional GAN)

  • 贡献:把 CNN 引入 GAN,并给出了一系列工程实践建议(如去掉 Pooling 用 Strided Conv,用 BatchNorm,用 LeakyReLU 等)。
  • 地位:GAN 走向实用化的第一步。

3.2 WGAN (Wasserstein GAN)

  • 贡献:从理论上解决了训练不稳定的问题。用 Wasserstein 距离(推土机距离)代替原来的 JS 散度来衡量分布差异。
  • 特点:训练超稳定,不用小心翼翼调节 D 和 G 的平衡。

3.3 StyleGAN

  • 贡献:生成人脸的霸主。
  • 核心:引入了 Style Transfer 的思想,通过映射网络(Mapping Network)把噪声解耦,控制生成的不同层级特征(粗糙特征控制脸型,精细特征控制发色)。

3.4 BigGAN

  • 贡献:简单粗暴,“大”就是好。更大的模型、更大的 Batch Size,生成了高分辨率、高质量的 ImageNet 图像。

4. VAE:变分自编码器 (Variational Autoencoders)

如果说 GAN 是“无中生有”,VAE 则是“先压后解”。它基于贝叶斯推断,是一个有着严谨数学推导的模型。

4.1 结构

  1. Encoder (推断网络):输入图片 xxx,输出隐变量 zzz 的分布(均值 μ\muμ 和方差 σ\sigmaσ)。
  2. Decoder (生成网络):从分布中采样 zzz,重构出图片 x^\hat{x}x^

4.2 核心技巧:重参数化 (Reparameterization Trick)

  • 问题:直接从 N(μ,σ2)N(\mu, \sigma^2)N(μ,σ2) 采样这个操作是不可导的,没法反向传播。
  • 解决:把采样变成 z=μ+σ⋅ϵz = \mu + \sigma \cdot \epsilonz=μ+σϵ,其中 ϵ∼N(0,1)\epsilon \sim N(0, 1)ϵN(0,1)。这样 ϵ\epsilonϵ 是常数,μ\muμσ\sigmaσ 就可以求导了。

4.3 损失函数

Loss=Reconstruction Loss+KL Divergence Loss = \text{Reconstruction Loss} + \text{KL Divergence} Loss=Reconstruction Loss+KL Divergence

  • 重构损失:生成的图要像原图。
  • KL 散度:隐变量 zzz 的分布要尽可能接近标准正态分布 N(0,1)N(0, 1)N(0,1)(为了方便采样)。

4.4 GAN vs VAE

  • GAN:生成的图清晰度高,但多样性可能差,训练难。
  • VAE:生成的图容易模糊(因为是概率分布的均值),但分布特性好,训练稳定。

5. Flow-based Models (流模型)

这是一个比较小众但优雅的流派(如 Glow)。

  • 核心:设计一系列可逆的变换函数 fff
  • 优点:可以精确计算似然概率 P(x)P(x)P(x),生成的隐变量极其平滑。
  • 缺点:为了保证可逆,网络结构受限,计算量巨大。

6. 代码实践:PyTorch 实现最简单的 GAN

这里用 MNIST 数据集演示一个全连接层的 GAN。

import torch
import torch.nn as nn

# 1. 生成器
class Generator(nn.Module):
    def __init__(self, z_dim=64, img_dim=784):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh()  # 输出归一化到 [-1, 1]
        )

    def forward(self, z):
        return self.gen(z)

# 2. 判别器
class Discriminator(nn.Module):
    def __init__(self, img_dim=784):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid() # 输出概率
        )

    def forward(self, img):
        return self.disc(img)

# 3. 训练循环伪代码
# for epoch in epochs:
#     for real_imgs, _ in dataloader:
#         
#         ### 训练判别器 ###
#         noise = torch.randn(batch_size, z_dim)
#         fake_imgs = gen(noise)
#         
#         real_score = disc(real_imgs)
#         fake_score = disc(fake_imgs.detach()) # 注意 detach,不传梯度给 G
#         
#         d_loss = -torch.mean(torch.log(real_score) + torch.log(1 - fake_score))
#         opt_disc.zero_grad()
#         d_loss.backward()
#         opt_disc.step()
#
#         ### 训练生成器 ###
#         # G 希望 D 判定 fake_imgs 为真 (score 接近 1)
#         fake_score = disc(fake_imgs) 
#         g_loss = -torch.mean(torch.log(fake_score))
#         
#         opt_gen.zero_grad()
#         g_loss.backward()
#         opt_gen.step()

7. 总结

  • GAN 是生成领域的“摇滚明星”,效果惊艳但脾气暴躁(难训练)。
  • VAE 是“数学教授”,理论优美但有时过于平滑(图像模糊)。
  • Flow 是“精密仪器”,精确但昂贵。
  • 预告:虽然 GAN 统治了一段时间,但明天我们要介绍的 扩散模型 (Diffusion Model) 才是当今 AIGC 真正的王者。它结合了概率生成的稳定性(像 VAE)和生成质量的高保真度(超越 GAN)。

思考:为什么 GAN 的 Loss 很难指示训练进度?(因为它是两个网络的博弈,Loss 下降不代表生成质量变好,可能只是对手变弱了。)

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐