-
Notifications
You must be signed in to change notification settings - Fork 5
/
sst_train.py
156 lines (132 loc) · 7.2 KB
/
sst_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import time
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.nn.functional as F
from wsi.bin.image_producer import ImageDataset
from fusionnet_generator import FusionGenerator
from dcgan_discriminator import Discriminator
from resnet_classifier import TranResnet34
import pytorch_ssim
import pytorch_dscsi
# Fix RNG seeds on CPU and on every visible GPU so training runs are reproducible.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
def freeze_model(model):
    """Switch *model* to inference mode and stop gradient tracking.

    Puts the module into eval() mode (affects dropout / batch-norm layers)
    and marks every parameter as not requiring gradients, so autograd and
    any optimizer will leave the module untouched.
    """
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)
def train(epochs: int = 5, batch_size: int = 4, learning_rate: float = 0.0002, sample_interval: int = 1):
    """Adversarially train the FusionGenerator G against discriminator D,
    with a frozen TranResnet34 (R) acting as a feature critic.

    Per training step, G receives grayscale patches and must produce color
    patches that (a) fool D, (b) preserve R's features (KL divergence), and
    (c) reconstruct the original color image (MSE). SSIM and DSCSI are
    computed for logging only — they do not enter the optimized loss.

    One CSV row per step is appended to ./losses/sst_loss.txt:
    epoch, global-step, D_loss, G_loss, reco_loss, 1-SSIM, DSCSI, accuracy.

    Args:
        epochs: number of passes over the tumor/normal patch datasets.
        batch_size: per-loader batch size; a tumor batch and a normal batch
            are concatenated each step, so D/G/R see 2*batch_size images.
        learning_rate: Adam learning rate shared by G and D.
        sample_interval: checkpoint G and D every this many epochs.

    Side effects: writes ./losses/sst_loss.txt and ./save_models/*.pkl;
    requires CUDA and the pre-trained TranResnet34 checkpoint on disk.
    """
    time_now = time.time()
    # G: 1-channel (gray) in, 3-channel (color) out; R is frozen below and
    # used both for the feature-preservation loss and a running accuracy.
    G = FusionGenerator(1, 3, 16).cuda()
    D = Discriminator(3, [32,64,128, 256, 512, 1024], 1).cuda()
    R = TranResnet34().cuda()
    R.load_state_dict(torch.load('./TranResnet34/save_models/TranResnet34_params.pkl'))
    # OR: R.load_state_dict(torch.load('./TranResnet34/save_models/best.ckpt')['state_dict'])
    freeze_model(R)
    criterion_fp = nn.KLDivLoss().cuda()    # feature-preservation loss (KL between R features)
    criterion_gan = nn.BCELoss().cuda()     # adversarial real/fake loss
    criterion_reco = nn.MSELoss().cuda()    # pixel-wise reconstruction loss
    ssim_loss = pytorch_ssim.SSIM()         # logged only, never backpropagated
    dscsi_loss = pytorch_dscsi.COLOR_DSCSI(7,1)  # logged only, never backpropagated
    betas = (0.5, 0.999)  # conventional DCGAN Adam momentum terms
    # Optimizers
    G_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=betas)
    D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=betas)
    # Fixed BCE targets, sized 2*batch_size because the tumor and normal
    # batches are concatenated before every D/G forward pass.
    # NOTE(review): torch.autograd.Variable is deprecated — plain tensors
    # behave identically on modern PyTorch.
    valid = Variable(torch.ones(batch_size * 2).cuda())
    fake = Variable(torch.zeros(batch_size * 2).cuda())
    # Change paths below to where you store your data
    LOSS_DIR = "./losses"
    LOSS_SSIM_FILE = os.path.join(LOSS_DIR, "sst_loss.txt")
    with open(LOSS_SSIM_FILE, "w") as lossSSIM:
        for epoch in range(epochs):
            # Datasets/loaders are rebuilt each epoch so iterators restart cleanly.
            dataset_tumor_train = ImageDataset('./wsi/patches/tumor_train','./wsi/jsons/train',normalize=True)
            dataloader_tumor = DataLoader(dataset_tumor_train, batch_size=batch_size, num_workers=2)
            dataset_normal_train = ImageDataset('./wsi/patches/normal_train','./wsi/jsons/train',normalize=True)
            dataloader_normal = DataLoader(dataset_normal_train, batch_size=batch_size, num_workers=2)
            steps1 = len(dataloader_tumor)-1 # consider list.txt
            steps2 = len(dataloader_normal)-1
            # Iterate only over full batches available in BOTH loaders; the
            # "-1" above drops the possibly short final batch so the fixed
            # valid/fake target tensors always match the data size.
            steps = min(steps1,steps2)
            batch_size = dataloader_tumor.batch_size
            dataiter_tumor = iter(dataloader_tumor)
            dataiter_normal = iter(dataloader_normal)
            D_losses = []
            G_losses = []
            correct = 0   # running count of R's correct predictions on generated images
            total = 0     # running count of images scored
            for step in range(steps):
                # image data and labels
                data_tumor, target_tumor, data_tumor_gray = next(dataiter_tumor)
                data_tumor = Variable(data_tumor.cuda())
                target_tumor = Variable(target_tumor.cuda())
                data_tumor_gray = Variable(data_tumor_gray.cuda())
                data_normal, target_normal, data_normal_gray = next(dataiter_normal)
                data_normal = Variable(data_normal.cuda())
                target_normal = Variable(target_normal.cuda())
                data_normal_gray = Variable(data_normal_gray.cuda())
                # Shuffle the concatenated tumor+normal batch with one shared
                # permutation so images, grays and labels stay aligned.
                idx_rand = Variable(torch.randperm(batch_size * 2).cuda())
                data = torch.cat([data_tumor, data_normal])[idx_rand]
                data_gray = torch.cat([data_tumor_gray, data_normal_gray])[idx_rand]
                target = torch.cat([target_tumor, target_normal])[idx_rand]
                # Train discriminator with real data
                D_valid_decision = D(data).squeeze()
                D_valid_loss = criterion_gan(D_valid_decision, valid)
                # Train discriminator with fake data
                data_gene = G(data_gray)
                D_fake_decision = D(data_gene).squeeze()
                D_fake_loss = criterion_gan(D_fake_decision, fake)
                # Back propagation
                D_loss = D_valid_loss + D_fake_loss
                # First columns of this step's CSV row: epoch, global step, D loss.
                # NOTE(review): (epoch+1)*steps+step starts the global-step
                # column at `steps`, not 0 — epoch*steps+step was likely
                # intended; confirm before relying on this column.
                lossSSIM.write("%5d,%5d,%10lf," % (epoch+1, (epoch+1)*steps+step,D_loss.data))
                D_optimizer.zero_grad()
                D_loss.backward()
                D_optimizer.step()
                # Train generator (fresh forward pass: D was just updated and
                # the previous graph was consumed by D_loss.backward()).
                data_gene = G(data_gray)
                D_fake_decision = D(data_gene).squeeze()
                # Computer feature preservation loss
                valid_features, _ = R(data)
                # Re-wrap to detach real-image features from the graph — only
                # the generated-image branch should receive gradients.
                valid_features = Variable(valid_features.cuda(), requires_grad=False)
                fake_features, result = R(data_gene)
                fake_features_lsm = F.log_softmax(fake_features,1)
                valid_features_sm = F.softmax(valid_features,1)
                fp_loss = criterion_fp(fake_features_lsm, valid_features_sm)
                # Compute image reconstruction loss
                reco_loss = criterion_reco(data_gene, data)
                ssim = ssim_loss(data,data_gene)
                dscsi = dscsi_loss(data, data_gene)
                # Back propagation: weighted sum of adversarial, feature-
                # preservation and reconstruction terms (0.2 / 0.4 / 0.4).
                G_loss = 0.2 * criterion_gan(D_fake_decision, valid) + 0.4 * fp_loss + 0.4 * reco_loss
                lossSSIM.write("%10lf,%10lf,%10lf,%10lf," % (
                    G_loss.data, reco_loss.data, 1-ssim.data, dscsi.data))
                # D's grads are zeroed too: G_loss.backward() flows through D,
                # and this keeps those gradients from lingering in D.
                D_optimizer.zero_grad()
                G_optimizer.zero_grad()
                G_loss.backward()
                G_optimizer.step()
                D_losses.append(D_loss.item())
                G_losses.append(G_loss.item())
                # Test the result: does frozen R still classify G's output correctly?
                _, predicted = result.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()
                # Last column of the CSV row: running accuracy so far this epoch.
                lossSSIM.write("%10lf\n" % (correct / total))
                time_spent = time.time() - time_now
                if (step + 1) % 20 == 0:
                    print("[Epoch %d/%d], [Step %d/%d], [D_loss: %.4f], [G_loss: %.4f],[FP_loss: %.4f], [Reco_loss: %.4f], [Accu:%3d%%], [RunTime:%.4f]"
                        % (epoch + 1, epochs, step + 1, steps, D_loss.item(), G_loss.item(),fp_loss.item(), reco_loss.item(), 100. * correct / total, time_spent))
            # Per-epoch summary over all steps of this epoch.
            D_avg_loss = torch.mean(torch.FloatTensor(D_losses))
            G_avg_loss = torch.mean(torch.FloatTensor(G_losses))
            print("[Epoch %d/%d], [Step %d/%d], [D_avg_loss: %.4f], [G_avg_loss: %.4f], [Accu:%3d%%], [RunTime:%.4f]"
                % (epoch + 1, epochs, step + 1, steps, D_avg_loss, G_avg_loss, 100.*correct/total, time_spent))
            # Periodic checkpointing of both networks.
            if (epoch+1) % sample_interval == 0:
                torch.save(D.state_dict(), './save_models/D_params_sst_%d.pkl'%(epoch+1))
                torch.save(G.state_dict(), './save_models/G_params_sst_%d.pkl'%(epoch+1))
            # torch.save(D.state_dict(), './save_models/D_params_sst.pkl')
            # torch.save(G.state_dict(), './save_models/G_params_sst.pkl')
    print("FINAL:[Epoch %d/%d], [Step %d/%d], [D_avg_loss: %.4f], [G_avg_loss: %.4f], [Accu:%3d%%], [RunTime:%.4f]"
        % (epoch + 1, epochs, step + 1, steps, D_avg_loss, G_avg_loss, 100.*correct/total, time_spent))
    print("Model saved.")
if __name__ == '__main__':
    # Run training with the default hyper-parameters when executed as a script.
    train()