11.7 Diffusion Model
A diffusion model is a class of generative models whose core idea is to gradually add noise to data and then learn to remove it step by step, thereby learning a complex data distribution and generating new samples.
11.7.1 Principle
In the forward process, let the original sample be \(x_0\). Noise is added step by step over \(T\) steps, producing \(x_1, x_2, \dots, x_T\).
Each step is defined as:
\[ q(x_t | x_{t-1}) = \mathcal{N}\left(x_t; \sqrt{1 - \beta_t}\,x_{t-1}, \, \beta_t I \right) \]
where:
- \(\beta_t \in (0, 1)\) is the noise strength at step \(t\);
- \(I\) is the identity covariance matrix.
After enough steps, the data converges to a standard Gaussian distribution:
\[ x_T \sim \mathcal{N}(0, I) \]
Because compositions of Gaussians remain Gaussian, the marginal at any time step \(t\) can be written in closed form:
\[ q(x_t | x_0) = \mathcal{N}\left(x_t; \sqrt{\bar{\alpha}_t}\,x_0, \, (1 - \bar{\alpha}_t)I \right) \]
where:
\[ \alpha_t = 1 - \beta_t, \quad \bar{\alpha}_t = \prod_{s=1}^t \alpha_s \]
This means any sample \(x_0\) can be noised directly to its state \(x_t\) at step \(t\) in a single shot, without simulating the chain step by step.
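As a quick illustration, the closed-form expression can be applied directly to a batch of toy samples. The following is a minimal sketch; the linear \(\beta\) schedule and the 1-D toy data are illustrative choices that match the full example in 11.7.2.
import torch

# Minimal sketch of the closed-form forward noising q(x_t | x_0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # linear beta schedule (assumed)
alphas = 1.0 - betas                         # alpha_t = 1 - beta_t
alpha_bars = torch.cumprod(alphas, dim=0)    # alpha_bar_t = prod_{s<=t} alpha_s

x0 = torch.randn(8, 1) * 2 + 3               # toy samples from N(3, 4)
t = torch.randint(0, T, (8,))                # one random time step per sample
eps = torch.randn_like(x0)                   # eps ~ N(0, I)

# x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, in one shot
x_t = torch.sqrt(alpha_bars[t]).view(-1, 1) * x0 \
    + torch.sqrt(1 - alpha_bars[t]).view(-1, 1) * eps
print(x_t.shape)  # torch.Size([8, 1])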
The reverse process then starts from pure noise \(x_T \sim \mathcal{N}(0, I)\) and learns a reverse Markov chain:
\[ p_\theta(x_{t-1} | x_t) = \mathcal{N}\left(x_{t-1}; \mu_\theta(x_t, t), \, \Sigma_\theta(x_t, t)\right) \]
Since the true \(q(x_{t-1} | x_t)\) is intractable, a neural network \(\epsilon_\theta(x_t, t)\) is trained to approximate the noise, and the mean is rewritten as:
\[ \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \, \epsilon_\theta(x_t, t) \right) \]
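One way to see where this expression comes from (the standard DDPM derivation): the closed-form forward process gives an estimate of the clean sample,
\[ \hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta(x_t, t) \right) \]
and substituting \(\hat{x}_0\) for \(x_0\) in the mean of the tractable posterior \(q(x_{t-1} | x_t, x_0)\) recovers exactly the \(\mu_\theta(x_t, t)\) above.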
The training objective is to minimize the mean squared error between the predicted noise and the true noise:
\[ L(\theta) = \mathbb{E}_{x_0, t, \epsilon}\left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right] \]
that is, the model should accurately predict the noise component at any noise level.
Training procedure (a single training step is sketched right after this list):
- sample a random time step \(t\);
- sample noise \(\epsilon \sim \mathcal{N}(0, I)\);
- compute the noised sample in one shot: \[ x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon \]
- train the model to predict the noise: \[ \epsilon_\theta(x_t, t) \approx \epsilon \]
- minimize the prediction error.
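The recipe above condenses into a single training step. The sketch below assumes that model takes \((x_t, t)\) and predicts the noise, as in the full example of 11.7.2; the helper name train_step is introduced here only for illustration.
import torch
import torch.nn.functional as F

def train_step(model, optimizer, x0, alpha_bars, T):
    # One DDPM training step: sample t and eps, noise x0 in closed form,
    # then regress the predicted noise onto the true noise.
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)   # random time steps
    eps = torch.randn_like(x0)                                   # true noise
    x_t = torch.sqrt(alpha_bars[t]).view(-1, 1) * x0 \
        + torch.sqrt(1 - alpha_bars[t]).view(-1, 1) * eps        # closed-form noising
    loss = F.mse_loss(model(x_t, t), eps)                        # ||eps - eps_theta||^2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()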
After training, generation starts from standard Gaussian noise and denoises it step by step:
\[ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z, \quad z \sim \mathcal{N}(0, I) \]
where \(\sigma_t\) is commonly set to \(\sqrt{\beta_t}\) and the noise term is dropped at the final step. Looping from \(t = T\) down to \(t = 1\) produces a new sample \(x_0\).
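A compact sketch of this reverse loop is shown below; it corresponds to the p_sample/sample functions in the full example of 11.7.2. The function name reverse_sample and the choice \(\sigma_t = \sqrt{\beta_t}\) are illustrative.
import torch

@torch.no_grad()
def reverse_sample(model, betas, alphas, alpha_bars, n_samples=1000, device="cpu"):
    # Start from pure noise x_T ~ N(0, I) and apply the update above
    # for t = T-1, ..., 0 (0-based indexing), with sigma_t = sqrt(beta_t).
    T = betas.shape[0]
    x_t = torch.randn(n_samples, 1, device=device)
    for t in reversed(range(T)):
        t_batch = torch.full((n_samples,), t, device=device)
        eps_pred = model(x_t, t_batch)                           # eps_theta(x_t, t)
        mean = (x_t - (1 - alphas[t]) / torch.sqrt(1 - alpha_bars[t]) * eps_pred) \
               / torch.sqrt(alphas[t])
        z = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
        x_t = mean + torch.sqrt(betas[t]) * z
    return x_t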
11.7.2 Example
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# ========================
# 1️⃣ Hyperparameters
# ========================
device = "cuda" if torch.cuda.is_available() else "cpu"
T = 1000  # number of diffusion steps
beta_start, beta_end = 1e-4, 0.02
betas = torch.linspace(beta_start, beta_end, T).to(device)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)
# ========================
# 2️⃣ Build a simple dataset (1-D Gaussian)
# ========================
N = 1000
x0 = torch.randn(N, 1).to(device) * 2 + 3  # data distribution: N(3, 4)
plt.hist(x0.cpu().numpy(), bins=50, density=True)
plt.title("Training data distribution")
plt.show()
# ========================
# 3️⃣ Forward noising function q(x_t | x_0)
# ========================
def q_sample(x0, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x0)
    sqrt_ab = torch.sqrt(alpha_bars[t]).view(-1, 1)
    sqrt_one_minus_ab = torch.sqrt(1 - alpha_bars[t]).view(-1, 1)
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise
# ========================
# 4️⃣ Denoising network ε_θ(x_t, t)
# ========================
class DenoiseNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x_t, t):
        # Normalize the time step and concatenate it to the input
        t_embed = t.float().unsqueeze(1) / T
        x_in = torch.cat([x_t, t_embed], dim=1)
        return self.net(x_in)
model = DenoiseNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# ========================
# 5️⃣ Training loop
# ========================
epochs = 1000
for epoch in range(epochs):
    t = torch.randint(0, T, (N,), device=device)  # random time steps
    noise = torch.randn_like(x0)
    x_t = q_sample(x0, t, noise)
    noise_pred = model(x_t, t)
    loss = F.mse_loss(noise_pred, noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
# ========================
# 6️⃣ Sampling (reverse diffusion from pure noise)
# ========================
@torch.no_grad()
def p_sample(model, x_t, t):
    # One reverse step: compute the mean from the predicted noise, then add sigma_t * z
    beta_t = betas[t]
    alpha_t = alphas[t]
    alpha_bar_t = alpha_bars[t]
    noise_pred = model(x_t, torch.tensor([t] * x_t.shape[0], device=device))
    mean = (1 / torch.sqrt(alpha_t)) * (
        x_t - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * noise_pred
    )
    if t > 0:
        z = torch.randn_like(x_t)
    else:
        z = torch.zeros_like(x_t)  # no noise is added at the final step
    return mean + torch.sqrt(beta_t) * z

@torch.no_grad()
def sample(model, n_samples=1000):
    # Start from pure Gaussian noise and denoise from t = T-1 down to t = 0
    x_t = torch.randn(n_samples, 1).to(device)
    for t in reversed(range(T)):
        x_t = p_sample(model, x_t, t)
    return x_t
samples = sample(model, n_samples=1000).cpu().numpy()
# ========================
# 7️⃣ Visualize the results
# ========================
plt.figure(figsize=(8,5))
plt.hist(x0.cpu().numpy(), bins=50, density=True, alpha=0.5, label="Real data")
plt.hist(samples, bins=50, density=True, alpha=0.5, label="Generated data")
plt.legend()
plt.title("DDPM training vs sampling result")
plt.show()