# This code is modified from https://github.com/gumusserv/CLIP-SalGan
import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

from Data_Utils import *
from get_data import *
from generator import *
from discriminator import *
from train import *  # NOTE(review): self-import of this module — likely unintentional; verify
def train_model(train_loader, val_loader, generator, discriminator, criterion, optimizer_G, optimizer_D, device, num_epochs=50):
    """Adversarially train a saliency-map generator against a discriminator.

    Each epoch alternates discriminator updates (real maps labeled 1,
    generated maps labeled 0) with generator updates (make the
    discriminator label the fakes as real), then evaluates both models
    on the validation loader with gradients disabled. Running losses
    are printed every 10 steps and the full loss history is written to
    'g2d2/loss.json'.

    Args:
        train_loader: iterable yielding (images, targets, texts_embeddings)
            batches for training.
        val_loader: same-shaped iterable for validation.
        generator: model mapping (images, texts_embeddings) -> saliency maps.
        discriminator: model scoring (maps, texts_embeddings) in [0, 1].
        criterion: binary classification loss (e.g. nn.BCELoss).
        optimizer_G: optimizer over the generator's parameters.
        optimizer_D: optimizer over the discriminator's parameters.
        device: torch device all batches are moved to.
        num_epochs: number of training epochs (default 50).
    """
    # Ensure the output directory exists up front; open()/json.dump below
    # would otherwise raise FileNotFoundError on a fresh checkout.
    os.makedirs('g2d2', exist_ok=True)
    record_dic = dict()
    for epoch in range(num_epochs):
        epoch_dic = dict()
        generator.train()
        discriminator.train()
        running_loss_G = 0.0
        running_loss_D = 0.0
        for i, (images, targets, texts_embeddings) in enumerate(train_loader):
            images = images.to(device)
            targets = targets.to(device)
            texts_embeddings = texts_embeddings.to(device)

            # --- Train discriminator: real maps -> 1, generated maps -> 0 ---
            optimizer_D.zero_grad()
            real_labels = torch.ones(images.size(0), 1, device=device)
            outputs = discriminator(targets, texts_embeddings)
            d_loss_real = criterion(outputs, real_labels)
            fake_targets = generator(images, texts_embeddings)
            fake_labels = torch.zeros(images.size(0), 1, device=device)
            # detach() so the discriminator update sends no gradient into G
            outputs = discriminator(fake_targets.detach(), texts_embeddings)
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_D.step()

            # --- Train generator: make D classify the fakes as real ---
            optimizer_G.zero_grad()
            outputs = discriminator(fake_targets, texts_embeddings)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()

            running_loss_G += g_loss.item()
            running_loss_D += d_loss.item()
            if (i + 1) % 10 == 0:
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], '
                      f'Generator Loss: {running_loss_G / (i + 1)}, Discriminator Loss: {running_loss_D / (i + 1)}')
                step_dic = dict()
                step_dic['G LOSS'] = running_loss_G / (i + 1)
                step_dic['D LOSS'] = running_loss_D / (i + 1)
                epoch_dic[f"Step [{i + 1}/{len(train_loader)}]"] = step_dic

        # --- Validation: generator loss against all-ones (real) labels ---
        generator.eval()
        discriminator.eval()
        with torch.no_grad():
            val_loss = 0.0
            for images, targets, texts_embeddings in val_loader:
                images = images.to(device)
                targets = targets.to(device)
                texts_embeddings = texts_embeddings.to(device)
                fake_targets = generator(images, texts_embeddings)
                outputs = discriminator(fake_targets, texts_embeddings)
                val_loss += criterion(outputs, torch.ones(images.size(0), 1, device=device)).item()
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: G - {g_loss.item()}, D - {d_loss.item()}, Val Loss: {val_loss / len(val_loader)}')
        step_dic = dict()
        step_dic["Train G Loss"] = g_loss.item()
        step_dic["Train D Loss"] = d_loss.item()
        step_dic["Val Loss"] = val_loss / len(val_loader)
        epoch_dic["Final"] = step_dic
        record_dic[epoch] = epoch_dic
        # Checkpoint the loss record every epoch so a crash mid-run does
        # not lose the whole history (previously written once at the end).
        with open('g2d2/loss.json', 'w') as f:
            json.dump(record_dic, f)
if __name__ == '__main__':
    # Paths to the stimulus images and their ground-truth saliency maps.
    image_directory_path = 'saliency/image'
    target_directory_path = 'saliency/map'
    image_paths, target_paths, text_descriptions = get_Data(image_directory_path, target_directory_path)
    train_data, val_data, test_data = split_dataset(image_paths, target_paths, text_descriptions)

    # All images and maps are resized to 256x256 and converted to tensors.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    batch_size = 16
    train_loader = create_dataloader(train_data, transform, batch_size=batch_size)
    val_loader = create_dataloader(val_data, transform, batch_size=batch_size)
    test_loader = create_dataloader(test_data, transform, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    generator = Generator2().to(device)
    discriminator = Discriminator2().to(device)

    # BCE on the discriminator's real/fake score; the small learning rate
    # with beta1=0.5 is the common stable Adam setting for GAN training.
    criterion = nn.BCELoss()
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.00002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.00002, betas=(0.5, 0.999))

    num_epochs = 50
    train_model(train_loader, val_loader, generator, discriminator, criterion, optimizer_G, optimizer_D, device, num_epochs)

    # Create the output directory before saving; torch.save does not create it.
    os.makedirs('g2d2', exist_ok=True)
    torch.save(generator.state_dict(), 'g2d2/generator.pt')
    torch.save(discriminator.state_dict(), 'g2d2/discriminator.pt')