-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
93 lines (83 loc) · 3.84 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torchnet as tnt
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torchnet.engine import Engine
from torchnet.logger import VisdomPlotLogger
from tqdm import tqdm
from data_utils import DatasetFromFolder
from model import Net
def processor(sample):
    """Run one forward pass for the torchnet Engine and return ``(loss, output)``.

    Parameters
    ----------
    sample : list
        ``[data, target, training]`` -- a DataLoader batch with the engine's
        ``state['train']`` flag appended by ``on_sample``.

    Returns
    -------
    tuple
        ``(loss, output)`` where ``output = model(data)`` and
        ``loss = criterion(output, target)``.

    Reads the module-level ``model`` and ``criterion`` created under
    ``__main__``.
    """
    data, target, training = sample
    # Nothing else in this script switches the model between train and eval
    # mode, so do it here from the phase flag (matters if Net has dropout or
    # batch-norm layers).
    model.train(training)
    if torch.cuda.is_available():
        data = data.cuda()
        target = target.cuda()
    # The validation pass never calls backward(), so don't build an autograd
    # graph there (saves memory/time); training keeps gradients enabled.
    with torch.set_grad_enabled(training):
        output = model(data)
        loss = criterion(output, target)
    return loss, output
def on_sample(state):
    """Engine hook: tag the current batch with the phase flag.

    Appends ``state['train']`` to the mutable ``state['sample']`` so that
    ``processor`` receives ``[data, target, training]``.
    """
    batch = state['sample']
    batch.append(state['train'])
def reset_meters():
    """Clear the module-level PSNR and loss meters.

    The same two meters are reused for training and validation, so they are
    cleared at the start of each epoch and again before the validation pass.
    """
    for meter in (meter_psnr, meter_loss):
        meter.reset()
def on_forward(state):
    """Engine hook: accumulate PSNR and loss for the batch just processed."""
    # NOTE(review): the target passed here is the batch tensor from the loader,
    # not the copy processor() moves to the GPU, while the output may be on the
    # GPU -- confirm PSNRMeter.add handles the device mismatch.
    predicted = state['output'].data
    reference = state['sample'][1]
    meter_psnr.add(predicted, reference)
    meter_loss.add(state['loss'].item())
def on_start_epoch(state):
    """Engine hook: prepare a training epoch.

    Clears the meters, advances the learning-rate schedule, and wraps the
    batch iterator in a tqdm progress bar.
    """
    reset_meters()
    # Since PyTorch 1.1, scheduler.step() must follow optimizer.step().
    # Stepping at the start of the very first epoch, before any optimizer
    # update, skips the first LR value and moves every MultiStepLR milestone
    # one epoch early (PyTorch also warns about it). Stepping at the start of
    # every later epoch is equivalent to stepping at the end of the previous one.
    # Assumes torchnet's Engine numbers epochs from 0 in state['epoch'] -- confirm.
    if state['epoch'] > 0:
        scheduler.step()
    state['iterator'] = tqdm(state['iterator'])
def on_end_epoch(state):
    """Engine hook: report training metrics, validate, and checkpoint.

    Prints and logs the epoch's training loss/PSNR, resets the shared meters,
    runs ``engine.test`` over ``val_loader``, prints and logs the validation
    loss/PSNR, then saves ``model.state_dict()`` to
    ``epochs/epoch_<upscale>_<epoch>.pt``.
    """
    import os

    epoch = state['epoch']

    train_loss = meter_loss.value()[0]
    train_psnr = meter_psnr.value()
    print('[Epoch %d] Train Loss: %.4f (PSNR: %.2f db)' % (epoch, train_loss, train_psnr))
    train_loss_logger.log(epoch, train_loss)
    train_psnr_logger.log(epoch, train_psnr)

    # The same meters are reused for validation, so clear the training totals.
    reset_meters()
    engine.test(processor, val_loader)

    val_loss = meter_loss.value()[0]
    val_psnr = meter_psnr.value()
    val_loss_logger.log(epoch, val_loss)
    val_psnr_logger.log(epoch, val_psnr)
    print('[Epoch %d] Val Loss: %.4f (PSNR: %.2f db)' % (epoch, val_loss, val_psnr))

    # torch.save does not create missing directories; without this the first
    # checkpoint fails with FileNotFoundError after a full epoch of training.
    os.makedirs('epochs', exist_ok=True)
    torch.save(model.state_dict(), 'epochs/epoch_%d_%d.pt' % (UPSCALE_FACTOR, epoch))
if __name__ == "__main__":
    # Script entry point: parse CLI options, build data/model/optimizer, wire
    # the torchnet Engine hooks, and train. Everything here stays at module
    # scope because processor() and the hook functions read these names as
    # module-level globals.
    parser = argparse.ArgumentParser(description='Train Super Resolution')
    parser.add_argument('--upscale_factor', default=2, type=int, help='super resolution upscale factor')
    parser.add_argument('--num_epochs', default=100, type=int, help='super resolution epochs number')
    opt = parser.parse_args()

    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    # Only touch the CUDA runtime when a GPU is present; the rest of the
    # script already falls back to CPU via torch.cuda.is_available().
    if torch.cuda.is_available():
        torch.cuda.set_device(0)
        torch.cuda.empty_cache()

    train_set = DatasetFromFolder('data/train', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),
                                  target_transform=transforms.ToTensor())
    val_set = DatasetFromFolder('data/val', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),
                                target_transform=transforms.ToTensor())
    # drop_last only for training; validation should score every image rather
    # than silently skipping the final partial batch.
    train_loader = DataLoader(dataset=train_set, num_workers=2, batch_size=8, shuffle=True, drop_last=True)
    val_loader = DataLoader(dataset=val_set, num_workers=2, batch_size=8, shuffle=False, drop_last=False)

    model = Net(upscale_factor=UPSCALE_FACTOR)
    criterion = nn.MSELoss(reduction='mean')
    if torch.cuda.is_available():
        model = model.cuda()
        criterion = criterion.cuda()
    print('# parameters:', sum(param.numel() for param in model.parameters()))

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # Learning rate is multiplied by 0.1 at scheduler steps 30 and 80.
    scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    engine = Engine()
    meter_loss = tnt.meter.AverageValueMeter()
    # NOTE(review): PSNRMeter is neither imported nor defined in the visible
    # source, so this line raises NameError as written. Add the import for the
    # project module that defines it at the top of the file.
    meter_psnr = PSNRMeter()

    train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'})
    train_psnr_logger = VisdomPlotLogger('line', opts={'title': 'Train PSNR'})
    val_loss_logger = VisdomPlotLogger('line', opts={'title': 'Val Loss'})
    val_psnr_logger = VisdomPlotLogger('line', opts={'title': 'Val PSNR'})

    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch

    engine.train(processor, train_loader, maxepoch=NUM_EPOCHS, optimizer=optimizer)