基于PyTorch实现的条件潜在扩散模型,用于医学图像模态转换(T1 MRI → T2 MRI)。
- 🏥 医学图像专用: 专门针对MRI T1到T2模态转换优化
- 🚀 高效架构: 使用VAE潜在空间降低计算复杂度
- 🎯 条件生成: 基于T1图像条件生成对应的T2图像
- 📊 完整评估: 包含PSNR、SSIM等医学图像质量评估指标
- ⚡ 快速推理: 支持DDIM采样器实现快速生成
- 🔧 易于使用: 提供完整的训练和推理pipeline
T1 图像 → VAE编码器 → 潜在表示 → 条件扩散模型 → 生成潜在表示 → VAE解码器 → T2 图像
-
VAE (变分自编码器)
- 将256×256图像编码到32×32×4潜在空间
- 大幅降低扩散模型的计算复杂度
- 支持高质量图像重建
-
条件UNet
- 基于注意力机制的UNet架构
- 交叉注意力层融合T1条件信息
- 时间步嵌入支持扩散过程
-
扩散模型
- 支持DDPM和DDIM采样
- 可配置的噪声调度策略
- 稳定的训练过程
python quick_start.py --mode demopython quick_start.py --mode full --epochs 20python quick_start.py --mode train --epochs 50python quick_start.py --mode inferencepip install -r requirements.txt或使用快速开始脚本自动安装:
python quick_start.py --mode demo- 创建数据目录结构:
data/
├── T1/ # T1 MRI图像 (PNG格式)
│ ├── image001.png
│ ├── image002.png
│ └── ...
└── T2/ # 对应的T2 MRI图像 (PNG格式)
├── image001.png
├── image002.png
└── ...
- 确保T1和T2图像文件名一一对应
- 图像格式:PNG,灰度图,推荐分辨率256×256
python train.py --mode full --epochs 100 --lr 1e-4第一阶段:训练VAE
python train.py --mode vae --epochs 50 --lr 1e-4 --kl_weight 1e-6第二阶段:训练扩散模型
python train.py --mode diffusion --epochs 100 --lr 1e-4 --resume_vae checkpoints/vae_best.pth--mode: 训练模式 (vae/diffusion/full)--epochs: 训练轮数--lr: 学习率--batch_size: 批次大小--kl_weight: VAE的KL散度损失权重--data_split: 训练/验证集分割比例--resume_vae: 恢复VAE检查点--resume_diffusion: 恢复扩散模型检查点
python inference.py \
--mode single \
--vae_path checkpoints/vae_best.pth \
--diffusion_path checkpoints/diffusion_best.pth \
--input path/to/t1_image.png \
--output path/to/generated_t2.pngpython inference.py \
--mode batch \
--vae_path checkpoints/vae_best.pth \
--diffusion_path checkpoints/diffusion_best.pth \
--input data/T1/ \
--output outputs/generated/python inference.py \
--mode evaluate \
--vae_path checkpoints/vae_best.pth \
--diffusion_path checkpoints/diffusion_best.pth \
--input data/T1/ \
--t2_dir data/T2/ \
--output outputs/evaluation/--use_ddim: 使用DDIM采样器(默认开启)--num_inference_steps: 推理步数(默认50步)--num_samples: 评估时的样本数量限制
主要配置在 config.py 中:
class Config:
# 数据配置
IMAGE_SIZE = 256
BATCH_SIZE = 8
# 模型配置
LATENT_DIM = 4
TIMESTEPS = 1000
# 训练配置
LEARNING_RATE = 1e-4
NUM_EPOCHS = 100系统提供以下评估指标:
- PSNR (峰值信噪比): 衡量图像重建质量
- SSIM (结构相似性): 评估图像结构保持程度
评估结果会保存为JSON格式:
{
"psnr_mean": 28.45,
"psnr_std": 2.31,
"ssim_mean": 0.892,
"ssim_std": 0.045,
"num_samples": 100
}训练过程中会生成:
checkpoints/: 模型检查点logs/: 训练日志和TensorBoard文件outputs/samples/: 训练过程中的生成样本outputs/reconstructions/: VAE重建结果outputs/comparisons/: 对比图像
- 编码器: 4层下采样,ResNet块 + 注意力机制
- 解码器: 4层上采样,对称结构
- 潜在维度: 4通道,32×32分辨率
- 输入/输出: 4通道潜在表示
- 时间嵌入: 正弦位置编码
- 条件融合: 交叉注意力机制
- 注意力层: 在32×32、16×16、8×8分辨率
- 前向过程: 线性或余弦噪声调度
- 反向过程: 学习噪声预测
- 采样: 支持DDPM和DDIM
-
内存优化:
- 使用梯度检查点
- 调整批次大小
- 启用混合精度训练
-
训练加速:
- 使用多GPU训练
- 预训练VAE权重
- 调整学习率调度
-
质量提升:
- 增加训练数据
- 调整损失函数权重
- 使用数据增强
-
CUDA内存不足:
# 减小批次大小 python train.py --batch_size 4 -
数据加载错误:
- 检查图像文件格式和命名
- 确保T1/T2图像配对正确
-
训练不稳定:
- 降低学习率
- 调整KL散度权重
- 使用梯度裁剪
- 支持3D MRI数据
- 多模态条件生成
- 实时推理优化
- Web界面部署
- 医学图像预处理工具
如果您使用了本项目,请考虑引用相关论文:
@article{rombach2022high,
title={High-resolution image synthesis with latent diffusion models},
author={Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj{\"o}rn},
journal={CVPR},
year={2022}
}MIT License
如有问题或建议,请提交Issue或Pull Request。