ctm_train.py
"""
Discrete consistency distillation training of the consistency trajectory model on a toy task.
Depending on `use_pretraining`, we either pretrain the diffusion model first and then
distill the consistency model from it, or train both at the same time, iteratively
updating the weights of the consistency model and the diffusion model.
"""
from tqdm import tqdm
import torch
from ctm.ctm import ConsistencyTrajectoryModel
from ctm.toy_tasks.data_generator import DataGenerator
from ctm.visualization.vis_utils import plot_main_figure
if __name__ == "__main__":
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    n_sampling_steps = 10
    use_pretraining = True
    cm = ConsistencyTrajectoryModel(
        data_dim=1,
        cond_dim=1,
        sampler_type='ddim',
        lr=1e-4,
        sigma_data=0.5,
        sigma_min=0.05,
        solver_type='heun',
        sigma_max=5,
        n_discrete_t=18,
        conditioned=False,
        diffusion_lambda=1,
        device=device,
        rho=7,
        ema_rate=0.999,
        use_teacher=use_pretraining,
    )
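    # (Illustrative sketch, an assumption about the model internals) sigma_min, sigma_max,
    # rho, and n_discrete_t look like the usual knobs of a Karras/EDM-style discretization:
    #   sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    # Computed here only for illustration; the model builds its own schedule.
    example_sigmas = (5 ** (1 / 7) + torch.linspace(0, 1, 18) * (0.05 ** (1 / 7) - 5 ** (1 / 7))) ** 7  # 5 -> 0.05 over 18 steps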
    train_epochs = 2003
    # choose one of the following toy tasks: 'three_gmm_1D', 'uneven_two_gmm_1D', 'two_gmm_1D', 'single_gaussian_1D'
    data_manager = DataGenerator('two_gmm_1D')
    samples, cond = data_manager.generate_samples(5000)
    samples = samples.reshape(-1, 1).to(device)
    cond = cond.reshape(-1, 1).to(device)
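    # samples and cond now both have shape (5000, 1) and live on `device`, matching
    # data_dim=1 and cond_dim=1 above; with conditioned=False, cond is presumably ignored.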
    # With pretraining enabled, the progress bar spans both training phases.
    pbar = tqdm(total=2 * train_epochs if use_pretraining else train_epochs)
    # First pretrain the diffusion model, then train the consistency model.
    if use_pretraining:
        for i in range(train_epochs):
            diff_loss = cm.diffusion_train_step(samples, cond, i, train_epochs)
            pbar.set_description(f"Step {i}, Diff Loss: {diff_loss:.8f}")
            pbar.update(1)
        cm.update_teacher_model()
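        # (Sketch, not the library's actual code) update_teacher_model presumably copies
        # or EMA-blends the pretrained diffusion weights into the frozen teacher used for
        # distillation; an EMA-style blend with ema_rate=0.999 would look like:
        #   with torch.no_grad():
        #       for p_t, p in zip(teacher.parameters(), model.parameters()):
        #           p_t.mul_(0.999).add_(p, alpha=1 - 0.999)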
        plot_main_figure(
            data_manager.compute_log_prob,
            cm,
            200,
            train_epochs,
            sampling_method='euler',
            n_sampling_steps=n_sampling_steps,
            x_range=[-4, 4],
            save_path='./plots/'
        )
    # Train the consistency trajectory model, either simultaneously with the diffusion model or after pretraining.
    for i in range(train_epochs):
        loss, ctm_loss, diffusion_loss, gan_loss = cm.train_step(samples, cond)
        pbar.set_description(f"Step {i}, Loss: {loss:.8f}, CTM Loss: {ctm_loss:.8f}, Diff Loss: {diffusion_loss:.8f}, GAN Loss: {gan_loss:.8f}")
        pbar.update(1)
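    # (Assumption) the total loss returned by train_step is presumably a weighted sum of
    # its components, e.g. loss = ctm_loss + diffusion_lambda * diffusion_loss + gan_loss.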
    pbar.close()
    # Plot the training results for both the multi-step (euler) and the one-step sampler
    # to compare them. When pretraining is used, the euler plot was already produced
    # right after diffusion pretraining above, so only the one-step plot is added here.
    if not use_pretraining:
        plot_main_figure(
            data_manager.compute_log_prob,
            cm,
            200,
            train_epochs,
            sampling_method='euler',
            n_sampling_steps=n_sampling_steps,
            x_range=[-4, 4],
            save_path='./plots/'
        )
    plot_main_figure(
        data_manager.compute_log_prob,
        cm,
        200,
        train_epochs,
        sampling_method='onestep',
        x_range=[-4, 4],
        save_path='./plots/ctm'
    )
    print('done')