-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
117 lines (94 loc) · 3.98 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import os
from collections import OrderedDict
import data
from options.test_options import TestOptions
# from models.pix2pix_model import Pix2PixModel
from util.visualizer import Visualizer
from util import html
import torch
# Parse command-line options for the test phase.
opt = TestOptions().parse()

# Select the generator variant. The flags are checked in this order, so if
# several are set the first matching one wins; with none set, the plain
# Pix2PixModel is used.
if opt.dual:
    from models.pix2pix_dualmodel import Pix2PixModel
elif opt.dual_segspade:
    from models.pix2pix_dual_segspademodel import Pix2PixModel
elif opt.box_unpair:
    from models.pix2pix_dualunpair import Pix2PixModel
else:
    from models.pix2pix_model import Pix2PixModel

dataloader = data.create_dataloader(opt)

model = Pix2PixModel(opt)
model.eval()

visualizer = Visualizer(opt)

# Create a webpage that summarizes all results, stored under
# <results_dir>/<name>/<phase>_<which_epoch>.
web_dir = os.path.join(opt.results_dir, opt.name,
                       '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir,
                    'Experiment = %s, Phase = %s, Epoch = %s' %
                    (opt.name, opt.phase, opt.which_epoch))

# Batch entry displayed as the "input_label" visual: the retrieved label
# when retrieval memory is enabled, otherwise the regular input label.
# Choosing it once avoids building the visuals dict twice per image.
label_key = 'retrival_label_list' if opt.retrival_memory else 'label'

# Test: run inference batch by batch and save every synthesized image.
# torch.no_grad() keeps autograd from recording anything during inference.
with torch.no_grad():
    for i, data_i in enumerate(dataloader):
        # Stop once the batches already processed cover opt.how_many
        # images (the limit is enforced at whole-batch granularity).
        if i * opt.batchSize >= opt.how_many:
            break

        img_path = data_i['path']
        generated = model(data_i, mode='inference')
        for b in range(generated.shape[0]):
            print('process image... %s' % img_path[b])
            # NOTE(review): [0:35] is a hard-coded slice along the first
            # dimension of the per-sample label; 35 presumably matches a
            # dataset's label count (e.g. Cityscapes) — confirm.
            visuals = OrderedDict([('input_label', data_i[label_key][b][0:35]),
                                   ('synthesized_image', generated[b])])
            visualizer.save_images(webpage, visuals, img_path[b:b + 1])

# Write the summary webpage once, after all images have been saved.
webpage.save()