V-prediction is a parameterization of the diffusion training objective. It was first introduced by the Brain team of Google Research in 2022 for fast sampling of diffusion models [1]. Instead of estimating the noise `eps`, a v-prediction model is re-parameterized to predict the velocity `v = alpha_t * eps - sigma_t * x0`, so that the implied prediction of the clean sample `x0` stays well-behaved across the whole noise schedule.
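As a concrete illustration, the v-target can be computed from the clean sample and the Gaussian noise as below. This is a minimal sketch assuming a variance-preserving schedule, not the repo's training code; the function name is ours.

```python
def v_target(x0, noise, alpha_t, sigma_t):
    """Velocity target from [1]: v = alpha_t * eps - sigma_t * x0.

    Assumes a variance-preserving schedule (alpha_t**2 + sigma_t**2 == 1),
    where the noisy latent is z_t = alpha_t * x0 + sigma_t * noise.
    Predicting v is equivalent to predicting x0 and eps jointly, with
    targets that stay bounded at every signal-to-noise ratio.
    """
    return alpha_t * noise - sigma_t * x0
```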
MindONE supports v-prediction for Stable Diffusion models based on MindSpore and Ascend platforms.
Please make sure the following dependencies are installed:

- mindspore >= 1.9 [install] (2.0 is recommended for the best performance)
- python >= 3.7
- openmpi 4.0.3 (for distributed training/evaluation) [install]
Install the dependent packages by running:

```shell
pip install -r requirements.txt
```
Please download the pretrained SD2.0-768-v checkpoint and put it in the `stable_diffusion_v2/models` folder.
The text-image pair dataset for finetuning should follow the file structure below:

```text
dir
├── img1.jpg
├── img2.jpg
├── img3.jpg
└── img_txt.csv
```
`img_txt.csv` is the annotation file in the following format:

```text
dir,text
img1.jpg,a cartoon character with a potted plant on his head
img2.jpg,a drawing of a green pokemon with red eyes
img3.jpg,a red and white ball with an angry look on its face
```
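For reference, the annotation file can be read with Python's standard `csv` module. A minimal sketch (the helper below is ours, not part of the repo), assuming the `dir`/`text` columns shown above:

```python
import csv
import os

def load_annotations(data_dir):
    """Return (image_path, caption) pairs from data_dir/img_txt.csv."""
    pairs = []
    path = os.path.join(data_dir, "img_txt.csv")
    with open(path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):  # columns: dir, text
            pairs.append((os.path.join(data_dir, row["dir"]), row["text"]))
    return pairs
```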
For convenience, we have prepared two public text-image datasets that follow the above format:

- pokemon-blip-caption dataset, containing 833 pokemon-style images with BLIP-generated captions.
- Chinese-art blip caption dataset, containing 100 Chinese art-style images with BLIP-generated captions.

To use them, please download pokemon_blip.zip and chinese_art_blip.zip from the openi dataset website, then unzip them to a local directory, e.g. `./datasets/pokemon_blip`.
We will use the `train_text_to_image.py` script for v-prediction finetuning. Before running, please make sure `image_size` is set to `768`, and modify the following arguments to your local paths, either in the shell or in the config file `train_config_vanilla_v2_vpred.yaml`:
```text
--data_path=/path/to/data
--output_path=/path/to/save/output_data
--pretrained_model_path=/path/to/pretrained_model
```
Then, execute the script to launch finetuning:

```shell
python train_text_to_image.py \
    --train_config "configs/train/train_config_vanilla_v2_vpred.yaml" \
    --data_path "datasets/pokemon_blip/train" \
    --output_path "output/vpred_vanilla_finetune_pokemon/txt2img" \
    --pretrained_model_path "models/sd_v2_768_v-e12e3a9b.ckpt"
```
After training, the finetuned checkpoint will be saved in `{output_path}/ckpt/txt2img/ckpt/rank_0/sd-72.ckpt`.
Below are some arguments that you may want to tune for better performance on your dataset:

- train_batch_size: the batch size for training.
- start_learning_rate and end_learning_rate: the initial and final learning rates for training.
- epochs: the number of training epochs.
- use_ema: whether to use EMA (exponential moving average) for model weight smoothing.

For a full description of the arguments, please run `python train_text_to_image.py -h`.
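For instance, assuming these names map directly to command-line flags of `train_text_to_image.py` (verify with `-h`; the exact flag syntax may differ), the finetuning command could be extended as:

```shell
# Illustrative values only; tune them for your dataset.
python train_text_to_image.py \
    --train_config "configs/train/train_config_vanilla_v2_vpred.yaml" \
    --data_path "datasets/pokemon_blip/train" \
    --output_path "output/vpred_vanilla_finetune_pokemon/txt2img" \
    --pretrained_model_path "models/sd_v2_768_v-e12e3a9b.ckpt" \
    --train_batch_size 4 \
    --start_learning_rate 1e-5 \
    --end_learning_rate 1e-7 \
    --epochs 72 \
    --use_ema True
```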
To perform text-to-image generation with the finetuned v-prediction checkpoint, first modify `configs/v2-inference.yaml` as follows to switch from eps-prediction to v-prediction:

```yaml
model:
  prediction_type: "v"
```
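For intuition on what this switch means: with `prediction_type: "v"`, the network output is interpreted as the velocity `v` rather than the noise `eps`, and the sampler recovers the `x0` and `eps` estimates from it. A minimal sketch of the conversion (our own illustration under a variance-preserving schedule, `alpha_t**2 + sigma_t**2 == 1`, not the repo's exact code):

```python
def from_v_prediction(z_t, v, alpha_t, sigma_t):
    """Recover the x0 and eps estimates from a v-prediction output [1]."""
    x0_hat = alpha_t * z_t - sigma_t * v   # predicted clean latent
    eps_hat = sigma_t * z_t + alpha_t * v  # equivalent noise prediction
    return x0_hat, eps_hat
```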
Then run:

```shell
python text_to_image.py \
    --prompt "A drawing of a fox with a red tail" \
    --config configs/v2-inference.yaml \
    --output_path ./output/ \
    --W 768 \
    --H 768 \
    --ckpt_path {path/to/v_prediction_checkpoint_after_finetune}
```
Please update `ckpt_path` according to your finetuning settings.
Here are some example results.

*Images generated by Stable Diffusion 2.0 v-prediction finetuned on the pokemon-blip dataset*

*Images generated by Stable Diffusion 2.0 v-prediction finetuned on the chinese-art-blip dataset*
We will evaluate the finetuned model on the test splits of pokemon_blip.zip and chinese_art_blip.zip. Let us run text-to-image generation conditioned on the prompts in the test set, then evaluate the quality of the generated images with the following steps.
- Before running, please modify the following arguments to your local paths:

```text
--data_path=/path/to/prompts.txt
--output_path=/path/to/save/output_data
--ckpt_path=/path/to/model_checkpoint
```
`prompts.txt` is a file that contains multiple prompts; each line is the caption of a real image in the test set, for example:

```text
a drawing of a spider on a white background
a drawing of a pokemon with blue eyes
a drawing of a pokemon pokemon with its mouth open
...
```
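If you need to build `prompts.txt` yourself, the captions can be extracted from the test split's `img_txt.csv`. A minimal sketch (this helper is ours, not part of the repo; the test-set path is an assumed example):

```python
import csv

# Assumed example paths; adjust to your local dataset layout.
with open("datasets/pokemon_blip/test/img_txt.csv", newline="", encoding="utf-8") as f_in, \
        open("prompts.txt", "w", encoding="utf-8") as f_out:
    for row in csv.DictReader(f_in):  # columns: dir, text
        f_out.write(row["text"] + "\n")
```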
- Run multiple-prompt inference on the test set (passing the prompt file via `--data_path`):

```shell
python text_to_image.py \
    --version "2.0" \
    --data_path {path/to/prompts.txt} \
    --config configs/v2-inference.yaml \
    --output_path output/ \
    --seed 42 \
    --n_iter 4 \
    --n_samples 1 \
    --W 768 \
    --H 768 \
    --sampling_steps 15 \
    --dpm_solver \
    --scale 9 \
    --ckpt_path "models/sd_v2_768_v-e12e3a9b.ckpt"
```
The generated images will be saved in the `{output_path}/samples` folder.
Note that the following hyper-parameters will affect the generation and evaluation results:

- sampler: the diffusion sampler
- sampling_steps: the number of sampling steps
- scale: the scale for classifier-free (unconditional) guidance

For more details, please run `python text_to_image.py -h`.
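For intuition on `scale`: it controls how strongly the conditional noise prediction is pushed away from the unconditional one. A minimal sketch of the standard classifier-free guidance combination (an illustration of the general technique, not the repo's exact sampler code):

```python
def classifier_free_guidance(eps_uncond, eps_cond, scale):
    """Combine conditional and unconditional predictions.

    scale == 1 reduces to the plain conditional prediction; larger values
    (e.g. 9 above) strengthen prompt adherence at some cost in diversity.
    """
    return eps_uncond + scale * (eps_cond - eps_uncond)
```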
- Evaluate the generated images:

```shell
python eval/eval_fid.py --real_dir {path/to/test_images} --gen_dir {path/to/generated_images}
python eval/eval_clip_score.py --image_path {path/to/test_images} --prompt {path/to/prompts_file} --load_checkpoint {path/to/checkpoint}
```
For details, please refer to the guideline Diffusion Evaluation.
Here are the evaluation results for our implementation, where the batch size is 1 and the finetuning method is vanilla for all experiments.
| Pretrained Model | Dataset | Objective | Image Size | Sampling Algo. | FID | CLIP Score |
|---|---|---|---|---|---|---|
| stable_diffusion_2.0_base | chinese_art_blip | eps-prediction | 512x512 | DPM Solver (scale: 9, steps: 15) | 265 | 34.9 |
| stable_diffusion_2.0_base | pokemon_blip | eps-prediction | 512x512 | DPM Solver (scale: 9, steps: 15) | 116 | 32.0 |
| stable_diffusion_2.0_768_v | chinese_art_blip | v-prediction | 768x768 | DPM Solver (scale: 9, steps: 15) | 267 | 34.3 |
| stable_diffusion_2.0_768_v | pokemon_blip | v-prediction | 768x768 | DPM Solver (scale: 9, steps: 15) | 120 | 31.9 |
Note that these numbers cannot comprehensively reflect generation quality; a visual evaluation is also necessary.
[1] Tim Salimans and Jonathan Ho. "Progressive Distillation for Fast Sampling of Diffusion Models." ICLR 2022. (Section 2.4)