别再死记硬背公式了!用Python代码一步步‘画’出DDPM扩散模型的核心推导

张开发
2026/4/16 1:00:18 15 分钟阅读

分享文章

别再死记硬背公式了!用Python代码一步步‘画’出DDPM扩散模型的核心推导
用Python代码可视化DDPM扩散模型:从数学公式到动态图像生成。

在深度学习领域,扩散模型正迅速成为生成式AI的核心技术之一。但对于许多开发者来说,那些复杂的数学推导常常让人望而生畏——β_t、α_t、高斯分布叠加……这些抽象符号背后究竟发生了什么?本文将带你用Python代码将这些公式"画"出来,通过可交互的视觉呈现,让扩散模型的核心原理变得触手可及。

## 1. 环境准备与基础工具

### 1.1 安装必要依赖

我们需要以下Python库来构建可视化系统:

```bash
pip install torch torchvision matplotlib numpy ipywidgets
```

### 1.2 创建基础工具函数

首先构建一些可视化辅助工具:

```python
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import numpy as np

def plot_images(images, titles=None, figsize=(15, 5)):
    """可视化图像序列"""
    plt.figure(figsize=figsize)
    for i, img in enumerate(images):
        plt.subplot(1, len(images), i + 1)
        plt.imshow(img.permute(1, 2, 0) if img.dim() == 3 else img, cmap='gray')
        plt.axis('off')
        if titles:
            plt.title(titles[i])
    plt.show()

def animate_diffusion(images, interval=200):
    """创建扩散过程动画"""
    fig = plt.figure(figsize=(6, 6))
    plt.axis('off')
    im = plt.imshow(images[0], cmap='gray')

    def update(frame):
        im.set_array(images[frame])
        return [im]

    return FuncAnimation(fig, update, frames=len(images), interval=interval, blit=True)
```

## 2. 前向扩散过程的可视化实现

### 2.1 噪声调度策略

DDPM的核心是设计合理的噪声添加计划:

```python
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """线性噪声调度表"""
    return torch.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    """余弦噪声调度表"""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

# 对比不同调度策略
timesteps = 1000
beta_linear = linear_beta_schedule(timesteps)
beta_cosine = cosine_beta_schedule(timesteps)

plt.plot(beta_linear.numpy(), label='Linear')
plt.plot(beta_cosine.numpy(), label='Cosine')
plt.xlabel('Timestep'); plt.ylabel('β value')
plt.title('Noise Schedule Comparison')
plt.legend(); plt.show()
```

### 2.2 单步加噪实现

让我们用代码实现前向扩散的一个步骤:

```python
def forward_diffusion_sample(x0, t, sqrt_alphas_cumprod,
                             sqrt_one_minus_alphas_cumprod, device='cpu'):
    """单步前向扩散"""
    noise = torch.randn_like(x0)
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)
    return sqrt_alphas_cumprod_t * x0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
```

### 2.3 完整前向过程可视化

现在让我们观察图像如何逐步被噪声破坏:

```python
# 加载示例图像
from torchvision.datasets import CIFAR10
dataset = CIFAR10(root='./data', train=True, download=True)
sample_img = torch.tensor(dataset[0][0]).permute(2, 0, 1).float() / 255

# 准备噪声调度
timesteps = 50
betas = linear_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# 生成扩散序列
diffusion_steps = [sample_img.unsqueeze(0)]
for t in range(timesteps):
    with torch.no_grad():
        xt, _ = forward_diffusion_sample(
            diffusion_steps[-1], torch.tensor([t]),
            sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod
        )
    diffusion_steps.append(xt.squeeze())

# 可视化关键步骤
selected_steps = [0, 10, 20, 30, 40, 49]
plot_images([diffusion_steps[i] for i in selected_steps],
            titles=[f'Step {i}' for i in selected_steps])
```

## 3. 反向去噪过程的核心实现

### 3.1 简易UNet噪声预测器

构建一个简化版的UNet来预测噪声:

```python
import torch.nn as nn
import torch.nn.functional as F

class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        # 下采样
        self.down1 = self.block(in_channels, 64)
        self.down2 = self.block(64, 128)
        self.down3 = self.block(128, 256)
        # 上采样
        self.up1 = self.block(256 + 128, 128)  # 跳跃连接
        self.up2 = self.block(128 + 64, 64)
        self.up3 = self.block(64 + in_channels, in_channels)
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        # 下采样路径
        x1 = self.down1(x)
        x2 = self.down2(self.pool(x1))
        x3 = self.down3(self.pool(x2))
        # 上采样路径
        x = self.upsample(x3)
        x = torch.cat([x, x2], dim=1)
        x = self.up1(x)
        x = self.upsample(x)
        x = torch.cat([x, x1], dim=1)
        x = self.up2(x)
        x = self.upsample(x)
        x = self.up3(x)
        return x
```

### 3.2 反向采样关键步骤

实现从噪声图像逐步恢复原始图像的过程:

```python
@torch.no_grad()
def reverse_diffusion_sample(model, x, t, sqrt_alphas_cumprod,
                             sqrt_one_minus_alphas_cumprod,
                             posterior_variance, device='cpu'):
    """单步反向去噪"""
    betas = 1 - (sqrt_alphas_cumprod**2 / (sqrt_alphas_cumprod**2).roll(1))
    betas[0] = 1 - (sqrt_alphas_cumprod[0]**2)
    # 预测噪声
    model.eval()
    predicted_noise = model(x)
    # 计算均值
    sqrt_recip_alphas = torch.sqrt(1.0 / (1 - betas))
    mean = sqrt_recip_alphas[t] * (x - betas[t] * predicted_noise / sqrt_one_minus_alphas_cumprod[t])
    # 计算方差
    if t == 0:
        return mean
    else:
        posterior_variance_t = posterior_variance[t]
        noise = torch.randn_like(x)
        return mean + torch.sqrt(posterior_variance_t) * noise
```

### 3.3 完整反向过程演示

让我们观察模型如何逐步去除噪声:

```python
# 初始化模型和参数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimpleUNet().to(device)
posterior_variance = betas * (1. - alphas_cumprod.roll(1)) / (1. - alphas_cumprod)
posterior_variance[0] = betas[0]

# 从纯噪声开始
x = torch.randn_like(sample_img.unsqueeze(0)).to(device)
reverse_steps = [x.cpu().squeeze()]

# 逐步去噪
for t in reversed(range(timesteps)):
    x = reverse_diffusion_sample(
        model, x, torch.tensor([t], device=device),
        sqrt_alphas_cumprod.to(device),
        sqrt_one_minus_alphas_cumprod.to(device),
        posterior_variance.to(device),
        device
    )
    reverse_steps.append(x.cpu().squeeze())

# 可视化去噪过程
selected_reverse_steps = [0, 10, 20, 30, 40, 49]
plot_images([reverse_steps[i] for i in selected_reverse_steps],
            titles=[f'Reverse Step {i}' for i in selected_reverse_steps])
```

## 4. 训练过程与损失函数

### 4.1 噪声预测损失

DDPM的核心是训练网络预测添加到图像中的噪声:

```python
def p_losses(model, x0, t, sqrt_alphas_cumprod,
             sqrt_one_minus_alphas_cumprod, device='cpu'):
    """计算噪声预测损失"""
    # 生成随机噪声
    noise = torch.randn_like(x0)
    # 前向扩散过程
    x_noisy, _ = forward_diffusion_sample(
        x0, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, device
    )
    # 预测噪声
    predicted_noise = model(x_noisy)
    # 计算损失
    return F.mse_loss(noise, predicted_noise)
```

### 4.2 训练循环实现

完整的训练过程实现:

```python
def train_diffusion_model(model, dataset, timesteps=200, batch_size=32,
                          epochs=10, lr=1e-3):
    """训练扩散模型"""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # 准备噪声调度
    betas = linear_beta_schedule(timesteps)
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, axis=0)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

    # 数据加载器
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        for step, (x0, _) in enumerate(loader):
            x0 = x0.to(device)
            # 随机采样时间步
            t = torch.randint(0, timesteps, (x0.shape[0],), device=device).long()
            # 计算损失
            loss = p_losses(
                model, x0, t,
                sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, device
            )
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print(f'Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}')
    return model
```

## 5. 高级可视化与交互探索

### 5.1 动态扩散过程展示

创建交互式扩散过程演示:

```python
from ipywidgets import interact, IntSlider

def show_diffusion_step(step=0):
    plt.figure(figsize=(6, 6))
    plt.imshow(diffusion_steps[step].permute(1, 2, 0))
    plt.title(f'Diffusion Step {step}')
    plt.axis('off')
    plt.show()

interact(show_diffusion_step, step=IntSlider(min=0, max=timesteps, step=1, value=0))
```

### 5.2 噪声调度影响分析

比较不同噪声调度对生成质量的影响:

```python
def compare_schedules():
    schedules = {
        'Linear': linear_beta_schedule(timesteps),
        'Cosine': cosine_beta_schedule(timesteps),
        'Quadratic': torch.linspace(1e-4, 0.02**0.5, timesteps)**2
    }
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (name, beta) in zip(axes, schedules.items()):
        alpha = 1 - beta
        alpha_cumprod = torch.cumprod(alpha, dim=0)
        ax.plot(alpha_cumprod.numpy())
        ax.set_title(name)
        ax.set_xlabel('Timestep')
        ax.set_ylabel('α cumprod')
    plt.tight_layout()
    plt.show()

compare_schedules()
```

### 5.3 潜在空间探索

可视化模型在不同噪声水平下的表现:

```python
def explore_latent_space(model, num_samples=5):
    model.eval()
    fig, axes = plt.subplots(num_samples, timesteps // 10, figsize=(15, 8))
    for i in range(num_samples):
        x = torch.randn(1, 3, 32, 32).to(device)
        for j, t in enumerate(reversed(range(0, timesteps, timesteps // 10))):
            with torch.no_grad():
                x = reverse_diffusion_sample(
                    model, x, torch.tensor([t], device=device),
                    sqrt_alphas_cumprod.to(device),
                    sqrt_one_minus_alphas_cumprod.to(device),
                    posterior_variance.to(device),
                    device
                )
            axes[i, j].imshow(x.cpu().squeeze().permute(1, 2, 0).clip(0, 1))
            axes[i, j].axis('off')
            if i == 0:
                axes[i, j].set_title(f't={t}')
    plt.tight_layout()
    plt.show()

explore_latent_space(model)
```

更多文章