扩散模型一文讲解

背景

图像生成领域最常见生成模型有GAN和VAE,2020年,DDPM(Denoising Diffusion Probabilistic Model)被提出,被称为扩散模型(Diffusion Model),同样可用于图像生成。近年扩散模型大热,Stability AI、OpenAI、Google Brain等相继基于扩散模型提出的以文生图,图像生成视频生成等模型。

原理介绍

扩散模型:和其他生成模型一样,实现从噪声(采样自简单的分布)生成目标数据样本。

扩散模型包括两个过程:前向过程(forward process)和反向过程(reverse process),其中前向过程又称为扩散过程(diffusion process)。无论是前向过程还是反向过程都是一个参数化的马尔可夫链(Markov chain),其中反向过程可用于生成数据样本(它的作用类似GAN中的生成器,只不过GAN生成器会有维度变化,而DDPM的反向过程没有维度变化)。

在这里插入图片描述在这里插入图片描述

- x0到xt 为逐步加噪过的前向程,噪声是已知的,该过程从原始图片逐步加噪至一组纯噪声。

- xt到x0 为将一组随机噪声还原为输入的过程。该过程需要学习一个去噪过程,直到还原一张图片。

前向过程

前向过程是加噪的过程,前向过程中图像 只和上一时刻的 有关, 该过程可以视为马尔科夫过程, 满足:
q ( x 1 : x T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q(x_1:x_T|x_0) = \prod_{t=1}^{T} q(x_t|x_{t-1}) q(x1:xTx0)=t=1Tq(xtxt1)

q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t I) q(xtxt1)=N(xt;1βt xt1,βtI)

q(x1:T | x0)表示从原始图像x0 开始,经过 T步加噪过程后得到噪声图像 的概率分布。

每一步q(x1:T | x0) 表示在第 t步,给定前一步的图像xt-1 ,得到当前图像xt 的条件概率分布,这是一个高斯分布,其均值为
1 − β t x t − 1 \sqrt{1 - \beta_t} x_{t-1} 1βt xt1
,方差为
β t \beta_t βt
,其中
β t \beta_t βt
是一个预定的噪声方差系数, I是单位矩阵。

其中不同的 t是
β t \beta_t βt
预先定义好的,由时间1~T 逐渐的递增,可以是Linear,Cosine等,满足 :
β 1 < β 2 < β 3 < . . . < β T \beta_1<\beta_2<\beta_3<...<\beta_T β1<β2<β3<...<βT
根据以上公式,可以通过重参数化采样得到。经过推导,可以得出 xt与 x0的关系。

在这里插入图片描述

Unet职责

无论在前向过程还是反向过程,Unet的职责都是根据当前的样本和时间t预测噪声。

在这里插入图片描述

Gaussion Diffusion职责

前向过程:从1到T的时间采样一个时间t ,生成一个随机噪声加到图片上,从Unet获取预测噪声,计算损失后更新Unet梯度

反向过程:先从正态分布随机采样和训练样本一样大小的纯噪声图片,从T-1到0逐步重复以下步骤:从xt 还原 xt-1。

训练过程

在这里插入图片描述

Algorithm1:Training

  • 从数据中抽取一个样本,
  • 从1-T中随机选取一个时间t
  • 将 和t传给GaussionDiffusion,GaussionDiffusion采样一个随机噪声,加到 ,形成 ,然后将 和t放入Unet,Unet根据t生成正弦位置编码和 结合,Unet预测加的这个噪声,并返回噪声,GaussionDiffusion计算该噪声和随机噪声的损失
  • 将神经网络Unet预测的噪声与之前GaussionDiffusion采样的随机噪声求L2损失,计算梯度,更新权重。
  • 重复以上步骤,直到网络Unet训练完成。

训练步骤中每个模块的交互如下图:

在这里插入图片描述

Algorithm2:Sampling

  • - 从标准正态分布采样出 xT
  • - 依次重复以下步骤:
  • - 从标准正态分布采样 ,为重参数化做准备
  • - 根据模型求出x,结合xt 和采样得到z利用重参数化技巧,得到 xt-1
  • - 循环结束后返回 x0

采样步骤中每个模块的交互如下图:

在这里插入图片描述

结合代码(MindSpore版本)讲解

代码主要分为以下几块:Unet、GaussianDiffusion、 Trainer

1. Unet

Unet网络结构如图:

在这里插入图片描述

1.1 正弦位置编码

DDPM每步训练是随机采样一个时间,为了让网络知道当前处理的是一系列去噪过程中的哪一个step,我们需要将当前t编码并传入网络之中,DDPM使用的Unet是time-condition Unet。

类似于Transformer的positional embedding,DDPM采用正弦位置编码(Sinusoidal Positional Embeddings),既需要位置编码有界又需要两个时间步长之间的距离与句子长度无关。为了满足这两点标准,一种思路是使用有界的周期性函数,而简单的有界周期性函数很容易想到sin和cos函数。

class SinusoidalPosEmb(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = np.exp(np.arange(half_dim) * - emb)
        self.emb = Tensor(emb, mindspore.float32)
        self.Concat = _get_cache_prim(ops.Concat)(-1)

    def construct(self, x):
        emb = x[:, None] * self.emb[None, :]
        emb = self.Concat((ops.sin(emb), ops.cos(emb)))
        return emb

DDPM的Unet有ResidualBlock和Attention Module

1.2 AttentionAttention的本质是从人类视觉注意力机制中获得灵感。大致是我们视觉在感知东西的时候,一般不会是一个场景从到头看到尾每次全部都看,而往往是根据需求观察注意特定的一部分。具体可以参考博客:TheLongGoodbye:浅谈Attention机制的理解

class Attention(nn.Cell):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_qkv = _get_cache_prim(Conv2d)(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
        self.to_out = _get_cache_prim(Conv2d)(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
self.map = ops.Map()
        self.partial = ops.Partial()
        self.bmm = BMM()
        self.split = ops.Split(axis=1, output_num=3)
        self.softmax = ops.Softmax(-1)

    def construct(self, x):
        b, c, h, w = x.shape
        qkv = self.split(self.to_qkv(x))
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
        q = q * self.scale
        sim = self.bmm(q.swapaxes(2, 3), k)
        attn = self.softmax(sim)
        out = self.bmm(attn, v.swapaxes(2, 3))
        out = out.swapaxes(-1, -2).reshape((b, -1, h, w))
        return self.to_out(out)

1.3 Residual Block

是ResNet的核心模块,可以防止网络退化。

class Residual(nn.Cell):
    """残差块"""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def construct(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

2. GaussianDiffusion

首先定义相关的概率值,与公式相对应:

self.betas = betas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
        self.sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))
        self.log_one_minus_alphas_cumprod = Tensor(np.log(1. - alphas_cumprod))
        self.sqrt_recip_alphas_cumprod = Tensor(np.sqrt(1. / alphas_cumprod))
        self.sqrt_recipm1_alphas_cumprod = Tensor(np.sqrt(1. / alphas_cumprod - 1))

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        self.posterior_variance = Tensor(posterior_variance)

        self.posterior_log_variance_clipped = Tensor(
            np.log(np.clip(posterior_variance, 1e-20, None)))
        self.posterior_mean_coef1 = Tensor(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.posterior_mean_coef2 = Tensor(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))

        p2_loss_weight = (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))\
                          ** - p2_loss_weight_gamma
        self.p2_loss_weight = Tensor(p2_loss_weight)

计算损失

基于Unet预测出noise,使用预测noise和真实noise计算损失:

def p_losses(self, x_start, t, noise, random_cond):
	    # 生成的真实noise
	    x = self.q_sample(x_start=x_start, t=t, noise=noise)
	    # if doing self-conditioning, 50% of the time, predict x_start from current set of times
	    if self.self_condition:
	        if random_cond:
	            _, x_self_cond = self.model_predictions(x, t)
	            x_self_cond = ops.stop_gradient(x_self_cond)
	        else:
	            x_self_cond = ops.zeros_like(x)
	    else:
	        x_self_cond = ops.zeros_like(x)
	    # model_out为基于U-net预测的pred_noise,此处self.model为Unet,ddpm默认预测目标是pred_noise。
	    model_out = self.model(x, t, x_self_cond)
	    if self.objective == 'pred_noise':
	        target = noise
	    elif self.objective == 'pred_x0':
	        target = x_start
	    elif self.objective == 'pred_v':
	        v = self.predict_v(x_start, t, noise)
	        target = v
	    else:
	        target = noise
		# 计算损失值
	    loss = self.loss_fn(model_out, target)
	    loss = loss.reshape(loss.shape[0], -1)
	    loss = loss * extract(self.p2_loss_weight, t, loss.shape)
	    return loss.mean()

采样

输出x_start,也就是原始图像,当sampling_time_steps< time_steps,用下方函数:

def ddim_sample(self, shape, clip_denoise=True):
    batch = shape[0]
    total_timesteps, sampling_timesteps, = self.num_timesteps, self.sampling_timesteps
    eta, objective = self.ddim_sampling_eta, self.objective

    # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
    times = np.linspace(-1, total_timesteps - 1, sampling_timesteps + 1).astype(np.int32)
    # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
    times = list(reversed(times.tolist()))
    time_pairs = list(zip(times[:-1], times[1:]))

	# 采样第一次迭代,Unet输入img为随机采样
    img = np.random.randn(*shape).astype(np.float32)
    x_start = None

    for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):
        # time_cond = ops.fill(mindspore.int32, (batch,), time)
        time_cond = np.full((batch,), time).astype(np.int32)
        x_start = Tensor(x_start) if x_start is not None else x_start
        self_cond = x_start if self.self_condition else None
        predict_noise, x_start, *_ = self.model_predictions(Tensor(img, mindspore.float32),
                                                            Tensor(time_cond),
                                                            self_cond,
                                                            clip_denoise)
        predict_noise, x_start = predict_noise.asnumpy(), x_start.asnumpy()
        if time_next < 0:
            img = x_start
            continue

        alpha = self.alphas_cumprod[time]
        alpha_next = self.alphas_cumprod[time_next]

        sigma = eta * np.sqrt(((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)))
        c = np.sqrt(1 - alpha_next - sigma ** 2)

        noise = np.random.randn(*img.shape)

        img = x_start * np.sqrt(alpha_next) + c * predict_noise + sigma * noise

    img = self.unnormalize(img)

    return img

3. Trainer 训练器

data_iterator中每次取出的数据集就是一个batch_size大小,每训练一个batch,self.step就会加1。

DDPM的trainer采用ema(指数移动平均)优化,ema不参与训练,只参与推理,比对变量直接赋值而言,移动平均得到的值在图像上更加平缓光滑,抖动性更小。具体代码参考代码仓中ema.py

print('training start')
        with tqdm(initial=self.step, total=self.train_num_steps, disable=False) as pbar:
            total_loss = 0.
            for (img,) in data_iterator:
                model.set_train()
                # # 随机采样time向量
                time_emb = Tensor(
                    np.random.randint(0, num_timesteps, (img.shape[0],)).astype(np.int32))
                noise = Tensor(np.random.randn(*img.shape), mindspore.float32)
                # 返回损失、计算梯度、更新梯度
                self_cond = random.random() < 0.5 if self.self_condition else False
                loss = train_step(img, time_emb, noise, self_cond)

                # 损失累加
                total_loss += float(loss.asnumpy())

                self.step += 1
                if self.step % gradient_accumulate_every == 0:
                    # ema和model的参数同步更新
                    self.ema.update()
                    pbar.set_description(f'loss: {total_loss:.4f}')
                    pbar.update(1)
                    total_loss = 0.

                accumulate_step = self.step // gradient_accumulate_every
                accumulate_remain_step = self.step % gradient_accumulate_every
                if self.step != 0 and accumulate_step % self.save_and_sample_every == 0\
                        and accumulate_remain_step == 0:

                    self.ema.set_train(False)
                    self.ema.synchronize()
                    batches = num_to_groups(self.num_samples, self.batch_size)
                    all_images_list = list(map(lambda n: self.ema.online_model.sample(batch_size=n),
                                               batches))
                    self.save_images(all_images_list, accumulate_step)
self.save(accumulate_step)
                    self.ema.desynchronize()

                if self.step >= gradient_accumulate_every * self.train_num_steps:
                    break

        print('training complete')

DDPM论文

- [Denoising Diffusion Probabilistic Models](U-Net: Convolutional Networks for Biomedical Image Segmentation)

代码链接

代码参考:1、MindSpore官方迁移实现:https://github.com/lvyufeng/denoising-diffusion-mindspore

2、pytorch:GitHub - lucidrains/denoising-diffusion-pytorch: Implementation of Denoising Diffusion Probabilistic Model in Pytorch

代码链接:GitHub - drizzlezyk/DDPM-MindSpore

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐