-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
69 lines (53 loc) · 1.75 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import numpy as np
from datetime import datetime
import cv2
import torch
import torch.utils.data as data
import torch.optim as optim
from torchvision import datasets, transforms
from model import VAE
from torch.utils.tensorboard import SummaryWriter
import my_mnist as mn
# Directory passed to the my_mnist loaders as the MNIST data location.
file_dir = 'C:\\workspace\\dataset\\MNIST\\'

# Create both preview windows up front as resizable (WINDOW_NORMAL) so the
# small images can be enlarged; eval() draws into them by these names.
for _window_name in ('img_sample', 'out_sample'):
    cv2.namedWindow(_window_name, cv2.WINDOW_NORMAL)
def eval(model_path='saved_model/VAE2019_09_22_01_30_17.pth',
         latent_num=20, delay_ms=500):
    """Show MNIST evaluation images next to their VAE reconstructions.

    Loads the evaluation split from ``file_dir``, restores a ``VAE`` from the
    saved state dict at ``model_path`` and, one image at a time, shows the
    input in the ``'img_sample'`` window and the model output in the
    ``'out_sample'`` window. Pressing Esc (key code 27) stops early.

    Note: the name shadows the builtin ``eval``; it is kept for callers.

    Args:
        model_path: Path of the saved ``state_dict`` to restore.
        latent_num: Latent size passed to ``VAE``; must match the checkpoint.
        delay_ms: Milliseconds ``cv2.waitKey`` waits between frames.
    """
    # Load the evaluation images and flatten them into one vector per sample.
    raw_images = mn.load_mnist_img(file_dir, mode='eval')
    # NOTE(review): dtype is whatever mnist_images_to_vector returns — confirm
    # it matches the dtype the VAE expects.
    images = torch.from_numpy(mn.mnist_images_to_vector(raw_images))

    # Model setup.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = VAE(latent_num)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
    # without it torch.load restores tensors to the device they were saved on.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    print(model)
    model.eval()

    with torch.no_grad():
        # Iterate over the tensor that is actually indexed so the slice can
        # never run past the available images.
        for i in range(images.size(0)):
            img = images[i:i + 1].to(device)  # keep a batch dimension of 1
            out, _ = model(img)  # second output (z) is not displayed
            cv2.imshow('img_sample', tensor_to_display_image(img))
            cv2.imshow('out_sample', tensor_to_display_image(out))
            if cv2.waitKey(delay_ms) == 27:  # Esc
                break
def tensor_to_display_image(array):
    """Turn the first sample of a tensor into a 28x28 ``uint8`` image.

    The values are scaled by 255, clamped to ``[0, 255]`` and then truncated
    to ``uint8``. Inputs are presumably in ``[0, 1]``; anything outside that
    range saturates.

    Args:
        array: Tensor (on any device) whose first index selects a sample
            holding 28*28 values; gradients are detached.

    Returns:
        A ``(28, 28)`` ``numpy`` array of dtype ``uint8``.
    """
    sample = array.detach().cpu().numpy()[0]
    pixels = sample.reshape(28, 28) * 255
    return np.clip(pixels, 0, 255).astype('uint8')
if __name__ == '__main__':
    # The preview windows are created at import time; release them even when
    # eval() stops early on Esc or raises.
    try:
        eval()
    finally:
        cv2.destroyAllWindows()