-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathRepairTrainer.py
More file actions
300 lines (264 loc) · 14 KB
/
RepairTrainer.py
File metadata and controls
300 lines (264 loc) · 14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import logging
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from ImageMetrics import ImageMetrics
from torch.cuda import amp
# Pick the compute device once at import time; prefer the GPU when one is present.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# -------------------------- 5. Training utilities (perceptual loss + LR schedulers) --------------------------
class RepairTrainer:
    """GAN trainer for image inpainting/repair models.

    Trains a generator/discriminator pair with an adversarial BCE loss,
    a weighted L1 reconstruction term and a VGG-based perceptual term.
    Supports optional CUDA mixed precision (AMP), cosine-annealing LR
    schedules, per-epoch FID/SSIM evaluation and metric plotting.
    """

    def __init__(self, generator, discriminator, lr=0.0002, beta1=0.5, perceptual_weight=10.0,
                 n_critic=5, output_dir='.', use_mask_for_generator_input=False, use_amp=False):
        """
        Args:
            generator: generator network; moved to the module-level ``device``.
            discriminator: discriminator network; moved to ``device``.
            lr: Adam learning rate shared by both optimizers.
            beta1: Adam beta1 shared by both optimizers (beta2 fixed at 0.999).
            perceptual_weight: weight of the perceptual-loss term.
            n_critic: the generator is updated once every ``n_critic`` batches.
            output_dir: directory where ``plot_metrics`` saves its figure.
            use_mask_for_generator_input: if True, a 1-channel hole mask is
                concatenated to the corrupted image before it is fed to G.
            use_amp: request mixed-precision training (honored only on CUDA).
        """
        self.G = generator.to(device)
        self.D = discriminator.to(device)
        self.lr = lr
        self.beta1 = beta1
        self.perceptual_weight = perceptual_weight
        self.n_critic = n_critic
        self.output_dir = output_dir  # where metric plots are written
        self.use_mask_for_generator_input = use_mask_for_generator_input
        # AMP is only meaningful on CUDA; silently disable it elsewhere.
        self.use_amp = bool(use_amp) and device.type == 'cuda'
        # Losses: adversarial BCE on logits + weighted L1 reconstruction.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.l1_loss = nn.L1Loss()
        self.l1_weight = 100.0  # pix2pix-style L1 weight
        # One Adam optimizer per network.
        self.opt_G = torch.optim.Adam(self.G.parameters(), lr=lr, betas=(beta1, 0.999))
        self.opt_D = torch.optim.Adam(self.D.parameters(), lr=lr, betas=(beta1, 0.999))
        # Single GradScaler shared by both optimizers (one update() per step).
        self.scaler = amp.GradScaler(enabled=self.use_amp)
        # Cosine annealing smoothly decays the LR towards eta_min.
        # NOTE(review): T_max=50 assumes ~50 epochs of training — confirm with the caller.
        self.scheduler_G = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt_G, T_max=50, eta_min=1e-6)
        self.scheduler_D = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt_D, T_max=50, eta_min=1e-6)
        # Metric helper; also hosts the VGG network used for the perceptual loss.
        self.metrics = ImageMetrics()
        # Per-step (losses) and per-epoch (FID/SSIM/LR) training history.
        self.logs = {
            "d_loss_real": [], "d_loss_fake": [], "g_loss_total": [],
            "g_loss_gan": [], "g_loss_l1": [], "g_loss_perceptual": [],
            "epoch_fid": [], "epoch_ssim": [], "lr_G": [], "lr_D": []
        }

    def _build_generator_input(self, corrupted_img, mask=None):
        """Return the tensor fed to the generator.

        If masks are disabled this is just ``corrupted_img``; otherwise a
        single-channel hole mask (1 where mask < 0.5, i.e. the missing
        region) is concatenated along the channel dimension.

        Raises:
            ValueError: masks are enabled but ``mask`` is None.
        """
        if not self.use_mask_for_generator_input:
            return corrupted_img
        if mask is None:
            raise ValueError("Generator expects mask input but mask is None")
        # Reduce a multi-channel mask to a single channel.
        mask_channel = mask[:, :1] if mask.size(1) > 1 else mask
        # Invert: 1 marks the hole (region to repair), 0 the known pixels.
        hole_mask = (mask_channel < 0.5).float()
        return torch.cat([corrupted_img, hole_mask], dim=1)

    def train_step(self, real_img, corrupted_img, mask, batch_idx):
        """One AMP-aware training step: always update D, update G every n_critic batches.

        Returns:
            The detached fake image produced during the D update.
        """
        # --- train the Discriminator ---
        self.opt_D.zero_grad(set_to_none=True)
        scaler_updated = False
        with amp.autocast(enabled=self.use_amp):
            d_out_real = self.D(real_img)
            # Build labels matching whatever output shape D actually produces
            # (e.g. PatchGAN maps rather than scalars).
            real_label = torch.ones_like(d_out_real, device=device)
            d_loss_real = self.bce_loss(d_out_real, real_label)
            gen_input = self._build_generator_input(corrupted_img, mask)
            fake_img = self.G(gen_input)
            # detach(): do not backprop the D loss into the generator.
            d_out_fake = self.D(fake_img.detach())
            fake_label = torch.zeros_like(d_out_fake, device=device)
            d_loss_fake = self.bce_loss(d_out_fake, fake_label)
            d_loss_total = (d_loss_real + d_loss_fake) / 2
        self.scaler.scale(d_loss_total).backward()
        self.scaler.step(self.opt_D)
        # record discriminator losses
        self.logs["d_loss_real"].append(d_loss_real.item())
        self.logs["d_loss_fake"].append(d_loss_fake.item())
        # --- train the Generator (once every n_critic D updates) ---
        if (batch_idx + 1) % self.n_critic == 0:
            self.opt_G.zero_grad(set_to_none=True)
            with amp.autocast(enabled=self.use_amp):
                fake_img_for_g = self.G(self._build_generator_input(corrupted_img, mask))
                d_out_fake_for_g = self.D(fake_img_for_g)
                # G wants D to label its output as real.
                real_label_for_g = torch.ones_like(d_out_fake_for_g, device=device)
                g_loss_gan = self.bce_loss(d_out_fake_for_g, real_label_for_g)
                g_loss_l1 = self.l1_loss(fake_img_for_g, real_img) * self.l1_weight
                # Perceptual loss runs in fp32: VGG statistics are unstable in fp16.
                with amp.autocast(enabled=False):
                    g_loss_perceptual = self.metrics.compute_perceptual_loss(
                        fake_img_for_g.float(), real_img.float()
                    ) * self.perceptual_weight
                g_loss_total = g_loss_gan + g_loss_l1 + g_loss_perceptual
            self.scaler.scale(g_loss_total).backward()
            self.scaler.step(self.opt_G)
            # One update() per step covers both optimizers.
            self.scaler.update()
            scaler_updated = True
            # record generator losses
            self.logs["g_loss_total"].append(g_loss_total.item())
            self.logs["g_loss_gan"].append(g_loss_gan.item())
            self.logs["g_loss_l1"].append(g_loss_l1.item())
            self.logs["g_loss_perceptual"].append(g_loss_perceptual.item())
        if not scaler_updated:
            self.scaler.update()
        return fake_img.detach()  # detached so callers cannot backprop through it

    def train_step_with_mask(self, real_img, corrupted_img, mask, batch_idx):
        """
        Training step for FFC-style generators that take (image, mask) directly
        (e.g. an FFCGeneratorWrapper). Runs in full precision — no AMP.

        Returns:
            The detached fake image produced during the D update.
        """
        # --- train the Discriminator ---
        self.opt_D.zero_grad(set_to_none=True)
        d_out_real = self.D(real_img)
        # Labels are sized dynamically from D's actual output shape.
        real_label = torch.ones_like(d_out_real, device=device)
        d_loss_real = self.bce_loss(d_out_real, real_label)
        # FFC generator takes the mask as a second positional argument.
        fake_img = self.G(corrupted_img, mask)
        d_out_fake = self.D(fake_img.detach())
        fake_label_for_d = torch.zeros_like(d_out_fake, device=device)
        d_loss_fake = self.bce_loss(d_out_fake, fake_label_for_d)
        d_loss_total = (d_loss_real + d_loss_fake) / 2
        d_loss_total.backward()
        self.opt_D.step()
        # record discriminator losses
        self.logs["d_loss_real"].append(d_loss_real.item())
        self.logs["d_loss_fake"].append(d_loss_fake.item())
        # --- train the Generator (once every n_critic D updates) ---
        if (batch_idx + 1) % self.n_critic == 0:
            self.opt_G.zero_grad(set_to_none=True)
            # Regenerate so the graph reaches the generator parameters.
            fake_img_for_g = self.G(corrupted_img, mask)
            d_out_fake_for_g = self.D(fake_img_for_g)
            # G wants D to label its output as real.
            real_label_for_g = torch.ones_like(d_out_fake_for_g, device=device)
            g_loss_gan = self.bce_loss(d_out_fake_for_g, real_label_for_g)
            g_loss_l1 = self.l1_loss(fake_img_for_g, real_img) * self.l1_weight
            g_loss_perceptual = self.metrics.compute_perceptual_loss(fake_img_for_g, real_img) * self.perceptual_weight
            g_loss_total = g_loss_gan + g_loss_l1 + g_loss_perceptual
            g_loss_total.backward()
            self.opt_G.step()
            # record generator losses
            self.logs["g_loss_total"].append(g_loss_total.item())
            self.logs["g_loss_gan"].append(g_loss_gan.item())
            self.logs["g_loss_l1"].append(g_loss_l1.item())
            self.logs["g_loss_perceptual"].append(g_loss_perceptual.item())
        return fake_img.detach()

    def step_schedulers(self):
        """Advance both LR schedulers (call once per epoch) and log the new LRs."""
        self.scheduler_G.step()
        self.scheduler_D.step()
        self.logs["lr_G"].append(self.opt_G.param_groups[0]['lr'])
        self.logs["lr_D"].append(self.opt_D.param_groups[0]['lr'])

    @torch.no_grad()
    def evaluate_epoch(self, dataloader, use_ffc=False):
        """Evaluate FID and SSIM over the full dataloader.

        Args:
            dataloader: yields (real_img, corrupted_img, mask) batches.
            use_ffc: if True the generator is called as G(img, mask)
                (FFC-style); otherwise via ``_build_generator_input``.

        Returns:
            (epoch_fid, epoch_ssim) as floats; both are also appended to logs.
        """
        self.G.eval()  # freeze BatchNorm/Dropout for evaluation
        all_real_features = []  # Inception features of real images (for FID)
        all_fake_features = []  # Inception features of generated images
        all_ssim_scores = []    # per-batch SSIM values
        for real_img, corrupted_img, mask in dataloader:
            real_img = real_img.to(device)
            corrupted_img = corrupted_img.to(device)
            mask = mask.to(device)
            # Generate the repaired image through the appropriate input path.
            if use_ffc:
                fake_img = self.G(corrupted_img, mask)
            else:
                gen_input = self._build_generator_input(corrupted_img, mask)
                fake_img = self.G(gen_input)
            # 1. SSIM for this batch
            batch_ssim = self.metrics.compute_ssim(real_img, fake_img)
            all_ssim_scores.append(batch_ssim)
            # 2. Inception features for the dataset-level FID
            real_feat = self.metrics.compute_inception_features(real_img)
            fake_feat = self.metrics.compute_inception_features(fake_img)
            all_real_features.append(real_feat)
            all_fake_features.append(fake_feat)
        # Aggregate over all batches.
        all_real_features = np.concatenate(all_real_features, axis=0)
        all_fake_features = np.concatenate(all_fake_features, axis=0)
        epoch_ssim = np.mean(all_ssim_scores)
        epoch_fid = self.metrics.compute_fid(all_real_features, all_fake_features)
        self.logs["epoch_ssim"].append(epoch_ssim)
        self.logs["epoch_fid"].append(epoch_fid)
        logging.info(f"当前Epoch评估结果 | FID: {epoch_fid:.2f} | SSIM: {epoch_ssim:.4f}")
        self.G.train()  # restore training mode
        return epoch_fid, epoch_ssim

    def plot_metrics(self):
        """Render a 3x2 grid of loss / FID / SSIM / LR curves and save it to output_dir."""
        fig, axes = plt.subplots(3, 2, figsize=(15, 18))
        # 1. discriminator losses
        axes[0, 0].plot(self.logs["d_loss_real"], label="D_loss_real", alpha=0.7)
        axes[0, 0].plot(self.logs["d_loss_fake"], label="D_loss_fake", alpha=0.7)
        axes[0, 0].set_title("Discriminator Loss")
        axes[0, 0].legend()
        axes[0, 0].grid(True)
        # 2. generator losses (including the perceptual term)
        axes[0, 1].plot(self.logs["g_loss_total"], label="G_loss_total", alpha=0.7)
        axes[0, 1].plot(self.logs["g_loss_gan"], label="G_loss_gan", alpha=0.7)
        axes[0, 1].plot(self.logs["g_loss_l1"], label="G_loss_l1", alpha=0.7)
        axes[0, 1].plot(self.logs["g_loss_perceptual"], label="G_loss_perceptual", alpha=0.7)
        axes[0, 1].set_title("Generator Loss (含感知损失)")
        axes[0, 1].legend()
        axes[0, 1].grid(True)
        # 3. FID curve (lower is better), annotated per epoch
        axes[1, 0].plot(range(1, len(self.logs["epoch_fid"]) + 1), self.logs["epoch_fid"],
                        label="FID", color="red", marker="o")
        axes[1, 0].set_title("FID Curve (Lower Better)")
        axes[1, 0].set_xlabel("Epoch")
        axes[1, 0].set_ylabel("FID Value")
        axes[1, 0].legend()
        axes[1, 0].grid(True)
        for i, fid in enumerate(self.logs["epoch_fid"]):
            axes[1, 0].annotate(f"{fid:.1f}", (i + 1, fid), textcoords="offset points",
                                xytext=(0, 10), ha='center')
        # 4. SSIM curve (closer to 1 is better), annotated per epoch
        axes[1, 1].plot(range(1, len(self.logs["epoch_ssim"]) + 1), self.logs["epoch_ssim"],
                        label="SSIM", color="green", marker="s")
        axes[1, 1].set_title("SSIM Curve (Higher Better)")
        axes[1, 1].set_xlabel("Epoch")
        axes[1, 1].set_ylabel("SSIM Value (0~1)")
        axes[1, 1].legend()
        axes[1, 1].grid(True)
        for i, ssim_val in enumerate(self.logs["epoch_ssim"]):
            axes[1, 1].annotate(f"{ssim_val:.3f}", (i + 1, ssim_val), textcoords="offset points",
                                xytext=(0, 10), ha='center')
        # 5. generator learning-rate curve (log scale)
        axes[2, 0].plot(range(1, len(self.logs["lr_G"]) + 1), self.logs["lr_G"],
                        label="LR_G", color="blue", marker="^")
        axes[2, 0].set_title("Generator Learning Rate")
        axes[2, 0].set_xlabel("Epoch")
        axes[2, 0].set_ylabel("Learning Rate")
        axes[2, 0].legend()
        axes[2, 0].grid(True)
        axes[2, 0].set_yscale('log')
        # 6. discriminator learning-rate curve (log scale)
        axes[2, 1].plot(range(1, len(self.logs["lr_D"]) + 1), self.logs["lr_D"],
                        label="LR_D", color="orange", marker="v")
        axes[2, 1].set_title("Discriminator Learning Rate")
        axes[2, 1].set_xlabel("Epoch")
        axes[2, 1].set_ylabel("Learning Rate")
        axes[2, 1].legend()
        axes[2, 1].grid(True)
        axes[2, 1].set_yscale('log')
        plt.tight_layout()
        save_path = os.path.join(self.output_dir, "metrics.png")
        plt.savefig(save_path)
        plt.close()
        logging.info(f"指标曲线图已保存至: {save_path}")

    def save_model(self, save_path):
        """Save a checkpoint: both networks, both optimizers and the training logs.

        Args:
            save_path: target file path; parent directories are created as needed.
        """
        # dirname('') would make makedirs raise FileNotFoundError when the
        # path has no directory component — only create dirs when one exists.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save({
            "G_state_dict": self.G.state_dict(),
            "D_state_dict": self.D.state_dict(),
            "opt_G_state_dict": self.opt_G.state_dict(),
            "opt_D_state_dict": self.opt_D.state_dict(),
            "training_logs": self.logs  # includes losses and epoch metrics
        }, save_path)
        logging.info(f"模型及日志已保存至:{save_path}")