11.7 Diffusion Model

A diffusion model is a class of generative models whose core idea is to gradually add noise to data and then learn to gradually remove it, thereby learning a complex data distribution and generating new samples.

11.7.1 Principle

In the forward process, let the original sample be \(x_0\); adding noise gradually over \(T\) steps produces \(x_1, x_2, \dots, x_T\).

Each step is defined as:

\[ q(x_t | x_{t-1}) = \mathcal{N}\left(x_t; \sqrt{1 - \beta_t}x_{t-1}, \, \beta_t I \right) \]

where:

  • \(\beta_t \in (0, 1)\) is the noise strength at step \(t\);
  • \(I\) is the identity covariance matrix.

After many such steps, the data converges to a standard Gaussian distribution:

\[ x_T \sim \mathcal{N}(0, I) \]

Using the composition property of Gaussians, the marginal at any step \(t\) can be written in closed form:

\[ q(x_t | x_0) = \mathcal{N}\left(x_t; \sqrt{\bar{\alpha_t}}x_0, \, (1 - \bar{\alpha_t})I \right) \]

where:

\[ \alpha_t = 1 - \beta_t, \quad \bar{\alpha_t} = \prod_{s=1}^t \alpha_s \]

This means any sample \(x_0\) can be noised to its step-\(t\) state \(x_t\) in a single shot, without simulating the chain step by step.
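A minimal sketch of this one-shot noising, using the same linear \(\beta\) schedule as the example in 11.7.2 (the endpoints 1e-4 and 0.02 are a common but otherwise arbitrary choice):

```python
import torch

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)      # linear noise schedule
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)  # \bar{\alpha}_t

x0 = torch.randn(10000, 1) * 2 + 3         # toy data: N(3, 4)
eps = torch.randn_like(x0)

# noise x0 directly to the final step t = T-1, no step-by-step chain
t = T - 1
x_t = torch.sqrt(alpha_bars[t]) * x0 + torch.sqrt(1 - alpha_bars[t]) * eps

print(alpha_bars[t].item())                 # close to 0
print(x_t.mean().item(), x_t.std().item())  # close to 0 and 1
```

Because \(\bar{\alpha_T}\) is nearly zero, the data's original mean and variance are washed out and \(x_T\) is statistically indistinguishable from \(\mathcal{N}(0, I)\), as claimed above.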

In the reverse process, starting from pure noise \(x_T \sim \mathcal{N}(0, I)\), we learn a reverse Markov chain:

\[ p_\theta(x_{t-1} | x_t) = \mathcal{N}\left(x_{t-1}; \mu_\theta(x_t, t), \, \Sigma_\theta(x_t, t)\right) \]

Since the true reverse conditional \(q(x_{t-1} | x_t)\) is intractable, a neural network \(\epsilon_\theta(x_t, t)\) is trained to predict the noise, and the mean is reparameterized as:

\[ \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha_t}}} \, \epsilon_\theta(x_t, t) \right) \]
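As a numerical sanity check on this reparameterization (an aside, not part of the derivation above): if the network predicted the noise perfectly, the mean formula should coincide with the mean of the tractable posterior \(q(x_{t-1} \mid x_t, x_0)\) from the standard DDPM derivation, whose coefficients are used below:

```python
import torch

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

t = 500                                # an arbitrary middle step
x0 = torch.randn(5, 1) * 2 + 3
eps = torch.randn_like(x0)
x_t = torch.sqrt(alpha_bars[t]) * x0 + torch.sqrt(1 - alpha_bars[t]) * eps

# mean via the epsilon parameterization, with the *true* noise plugged in
mu_eps = (x_t - (1 - alphas[t]) / torch.sqrt(1 - alpha_bars[t]) * eps) \
         / torch.sqrt(alphas[t])

# mean of the standard DDPM posterior q(x_{t-1} | x_t, x_0)
ab_prev = alpha_bars[t - 1]
mu_post = (torch.sqrt(ab_prev) * betas[t] / (1 - alpha_bars[t])) * x0 \
        + (torch.sqrt(alphas[t]) * (1 - ab_prev) / (1 - alpha_bars[t])) * x_t

print(torch.allclose(mu_eps, mu_post, atol=1e-4))  # True
```

The two expressions agree up to floating-point error, which is why predicting \(\epsilon\) suffices to define the reverse mean.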

The training objective is to minimize the mean squared error between the predicted and the true noise:

\[ L(\theta) = \mathbb{E}_{x_0, t, \epsilon}\left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right] \]

That is, the model learns to accurately predict the noise component at any noise level.

Implementation logic:

  1. Randomly sample a timestep \(t\)
  2. Sample noise \(\epsilon \sim \mathcal{N}(0, I)\)
  3. Compute the noised sample in one shot: \[ x_t = \sqrt{\bar{\alpha_t}}x_0 + \sqrt{1 - \bar{\alpha_t}}\epsilon \]
  4. Train the model to predict the noise: \[ \epsilon_\theta(x_t, t) \approx \epsilon \]
  5. Minimize the prediction error.

After training, sampling starts from standard Gaussian noise and denoises step by step:

\[ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha_t}}} \epsilon_\theta(x_t, t) \right) + \sigma_t z, \quad z \sim \mathcal{N}(0, I) \]

Looping from \(t = T\) down to \(t = 1\) yields a new sample \(x_0\).

11.7.2 Example

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# ========================
# 1️⃣ Hyperparameters
# ========================
device = "cuda" if torch.cuda.is_available() else "cpu"
T = 1000                       # number of diffusion steps
beta_start, beta_end = 1e-4, 0.02
betas = torch.linspace(beta_start, beta_end, T).to(device)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

# ========================
# 2️⃣ A simple dataset (1-D Gaussian)
# ========================
N = 1000
x0 = torch.randn(N, 1).to(device) * 2 + 3  # data distribution: N(3, 4)
plt.hist(x0.cpu().numpy(), bins=50, density=True)
plt.title("Training data distribution")
plt.show()

# ========================
# 3️⃣ Forward noising function q(x_t | x_0)
# ========================
def q_sample(x0, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x0)
    sqrt_ab = torch.sqrt(alpha_bars[t]).view(-1, 1)
    sqrt_one_minus_ab = torch.sqrt(1 - alpha_bars[t]).view(-1, 1)
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise

# ========================
# 4️⃣ Denoising network ε_θ(x_t, t)
# ========================
class DenoiseNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
    def forward(self, x_t, t):
        # normalize the timestep and concatenate it with the input
        t_embed = t.float().unsqueeze(1) / T
        x_in = torch.cat([x_t, t_embed], dim=1)
        return self.net(x_in)

model = DenoiseNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ========================
# 5️⃣ Training loop
# ========================
epochs = 1000
for epoch in range(epochs):
    t = torch.randint(0, T, (N,), device=device)  # random timesteps
    noise = torch.randn_like(x0)
    x_t = q_sample(x0, t, noise)

    noise_pred = model(x_t, t)
    loss = F.mse_loss(noise_pred, noise)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# ========================
# 6️⃣ Sampling (reverse diffusion from pure noise)
# ========================
@torch.no_grad()
def p_sample(model, x_t, t):
    beta_t = betas[t]
    alpha_t = alphas[t]
    alpha_bar_t = alpha_bars[t]
    noise_pred = model(x_t, torch.tensor([t]*x_t.shape[0], device=device))
    mean = (1 / torch.sqrt(alpha_t)) * (
        x_t - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * noise_pred
    )
    if t > 0:
        z = torch.randn_like(x_t)
    else:
        z = 0
    return mean + torch.sqrt(beta_t) * z

@torch.no_grad()
def sample(model, n_samples=1000):
    x_t = torch.randn(n_samples, 1).to(device)
    for t in reversed(range(T)):
        x_t = p_sample(model, x_t, t)
    return x_t

samples = sample(model, n_samples=1000).cpu().numpy()

# ========================
# 7️⃣ Visualize the results
# ========================
plt.figure(figsize=(8,5))
plt.hist(x0.cpu().numpy(), bins=50, density=True, alpha=0.5, label="Real data")
plt.hist(samples, bins=50, density=True, alpha=0.5, label="Generated data")
plt.legend()
plt.title("DDPM training vs sampling result")
plt.show()