-
Notifications
You must be signed in to change notification settings - Fork 0
/
deband_trainer.py
79 lines (74 loc) · 3.49 KB
/
deband_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import lightning as L
from torch import optim
from torchmetrics.image import PeakSignalNoiseRatio
from utils import write_images
from utils import get_indi_step,indi_transform, sample
class IndiDeband(L.LightningModule):
    """Lightning training wrapper for a debanding / restoration network.

    Each batch is a mapping with a ``'dry'`` tensor (passed to
    ``indi_transform`` as ``clean``) and a ``'wet'`` tensor (passed as
    ``dirty``). A step is drawn with ``get_indi_step``, the pair is blended
    with ``indi_transform``, and ``net`` is asked to produce a repaired
    tensor that is compared against ``dry``.

    NOTE(review): "indi" presumably refers to an iterative restoration
    scheme implemented in ``utils``; confirm against that module.

    Args:
        net: Callable network invoked as ``net(transformed, t)``.
        setup: Configuration object; this class reads ``setup.trainDevice``,
            ``setup.args.debug`` and ``setup.args.learningRate``.
        loss: Loss bundle exposing ``fn`` (callable taking prediction and
            target), ``weight`` (multiplier applied to the result) and
            ``transform`` (optional callable applied to both tensors before
            ``fn``, or ``None``).
        steps: Step count forwarded to ``get_indi_step`` and ``sample``.
    """

    def __init__(self,
                 net,
                 setup,
                 loss,
                 steps=10,
                 ):
        super().__init__()
        self.net = net
        self.setup = setup
        self.loss = loss
        self.steps = steps
        self.psnr = PeakSignalNoiseRatio()
        # Replaced on every validation batch. Starts empty so that
        # on_validation_epoch_end can detect "no validation batch was seen".
        self.image_map = {}

    def _indi_forward(self, dry, wet, **step_kwargs):
        """Run one network pass and return ``(repair, dry_on_device)``.

        ``step_kwargs`` are forwarded to ``get_indi_step`` so that training
        and validation can request different sampling behaviour.
        """
        device = self.setup.trainDevice
        fct, t = get_indi_step(dry, steps=self.steps, **step_kwargs)
        # Move the reference once so the loss and the PSNR metric both see a
        # target that lives on the same device as the network output.
        dry_dev = dry.to(device)
        transformed = indi_transform(fct.to(device), clean=dry_dev,
                                     dirty=wet.to(device))
        repair = self.net(transformed.to(device), t.to(device))
        return repair, dry_dev

    def _weighted_loss(self, repair, target):
        """Return ``loss.fn(repair, target) * loss.weight``.

        When ``loss.transform`` is set it is applied to both tensors first.
        """
        device = self.setup.trainDevice
        if self.loss.transform is not None:
            repair = self.loss.transform(repair)
            target = self.loss.transform(target)
        return self.loss.fn(repair.to(device), target.to(device)) * self.loss.weight

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and PSNR for one batch.

        Returns:
            The weighted loss, which Lightning uses for the backward pass.
        """
        dry, wet = batch['dry'], batch['wet']
        repair, dry_dev = self._indi_forward(dry, wet, deterministic=False)
        total = self._weighted_loss(repair, dry_dev)
        if self.setup.args.debug:
            print(total)
        self.log("train_loss", total, on_step=False, on_epoch=True)
        psnr = self.psnr(repair, dry_dev)
        self.log("train_psnr", psnr, on_step=False, on_epoch=True)
        return total

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and PSNR for one batch.

        Also stores clones of the batch's ``dry`` and ``wet`` tensors in
        ``self.image_map`` so the epoch-end hook can render the last batch.

        Returns:
            The weighted validation loss.
        """
        dry, wet = batch['dry'], batch['wet']
        # test returns all 1s for proper evaluation
        repair, dry_dev = self._indi_forward(dry, wet, test=True)
        total = self._weighted_loss(repair, dry_dev)
        # set each time, we only care about the last (for now)
        self.image_map = {"dry": dry.clone(), "wet": wet.clone()}
        if self.setup.args.debug:
            print(total)
        self.log("val_loss", total, on_step=False, on_epoch=True)
        psnr = self.psnr(repair, dry_dev)
        # NOTE(review): this validation metric is logged under "test_psnr".
        # The key is kept unchanged because callbacks or dashboards may
        # already monitor it; consider renaming to "val_psnr" once confirmed.
        self.log("test_psnr", psnr, on_step=False, on_epoch=True)
        return total

    def on_validation_epoch_end(self):
        """Sample a repair of the last validation batch and write the images."""
        # No validation batch was processed, so there is nothing to render.
        if not self.image_map:
            return None
        repair = sample(self.net, self.image_map["wet"], self.steps)
        self.image_map["repair"] = repair
        # Running without a logger is legitimate; skip writing in that case.
        if self.logger is None:
            return None
        writer = self.logger.experiment
        # writes images to whichever experiment logger we are using.
        write_images(self.image_map, writer=writer, epoch=self.current_epoch)
        return None

    # default to adam.
    def configure_optimizers(self, optimizer=optim.Adam):
        """Build the optimizer over all module parameters.

        Args:
            optimizer: Optimizer class (not instance); called as
                ``optimizer(params, lr=...)``. Defaults to ``optim.Adam``.

        Returns:
            The constructed optimizer, using ``setup.args.learningRate``.
        """
        optimizer = optimizer(self.parameters(), lr=self.setup.args.learningRate)
        return optimizer