-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_loss.py
62 lines (51 loc) · 2.15 KB
/
plot_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# This code is modified from https://github.com/gumusserv/CLIP-SalGan
# Reads the loss JSON files of all four runs and plots them together as subplots.
import json
from pathlib import Path

import matplotlib.pyplot as plt
# Relative paths of the loss logs to plot, listed in subplot order
# (panels are filled row by row).
files = [
    'g1d1/loss.json',
    'g1d2/loss.json',
    'g2d1/loss.json',
    'g2d2/loss.json',
]
# Function to extract loss data from a JSON file
def extract_losses(file_path):
with open(file_path, 'r') as file:
data = json.load(file)
step_g_losses, step_d_losses, val_losses, step_counts = [], [], [], []
step_counter = 0
epochs = list(data.keys())
for epoch in epochs:
for step in data[epoch]:
if step != 'Final':
step_g_losses.append(data[epoch][step]['G LOSS'])
step_d_losses.append(data[epoch][step]['D LOSS'])
step_counts.append(step_counter)
step_counter += 1
else:
step_g_losses.append(data[epoch][step]['Train G Loss'])
step_d_losses.append(data[epoch][step]['Train D Loss'])
step_counts.append(step_counter)
val_losses.append(data[epoch][step]['Val Loss'])
return step_g_losses, step_d_losses, val_losses, step_counts
# Extract the loss series for every run, then draw one panel per run.
loss_data = [extract_losses(file) for file in files]

# Panels are laid out two per row. With the four runs in `files` this is the
# original 2x2 grid at figsize (16, 10); more runs add rows instead of
# overflowing a fixed grid.
N_COLS = 2
n_rows = max(1, -(-len(loss_data) // N_COLS))  # ceiling division, at least one row
# Upper bound shown on the x-axis.
X_MAX = 50

fig, axs = plt.subplots(n_rows, N_COLS, figsize=(16, 5 * n_rows), squeeze=False)
axs = axs.ravel()
for ax, path, (step_g_losses, step_d_losses, val_losses, step_counts) in zip(axs, files, loss_data):
    ax.plot(step_counts, step_g_losses, label='G Loss', color='blue')
    ax.plot(step_counts, step_d_losses, label='D Loss', color='red')
    # NOTE(review): G/D losses are plotted against extract_losses' running
    # step counter, while Val Loss is plotted at 0..len-1 (one point per
    # 'Final' record). These x-scales differ whenever an epoch logs more than
    # one step, so the curves may not line up -- confirm the intended x-axis.
    ax.plot(range(len(val_losses)), val_losses, label='Val Loss', color='green', linestyle='--')
    ax.set_xlim(0, X_MAX)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    # Label the panel from its run directory (e.g. 'g1d1/loss.json' -> 'G1D1')
    # so titles always match the file actually plotted, whatever the order.
    ax.set_title(f'Losses for {Path(path).parent.name.upper()}')
    ax.legend()
    ax.grid(True)

# Hide leftover panels when the run count is not a multiple of N_COLS.
for ax in axs[len(loss_data):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()