-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
executable file
·71 lines (51 loc) · 1.87 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Module wrapping a pretrained ResNet-50, used to extract image features
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable as V
from config import *
def read_img(filename):
    """Load an image and convert it into a one-image batch tensor.

    The image object is obtained from ``get_image`` (provided outside this
    module; presumably it returns a PIL image -- confirm against ``config``).
    It is resized to 224x224, converted to a tensor, normalised per channel,
    and given a leading dimension of size 1.

    Args:
        filename: Passed straight through to ``get_image``.

    Returns:
        The preprocessed tensor with an extra leading dimension of size 1.
    """
    image = get_image(filename)

    # Per-channel mean / std; these are the values conventionally used with
    # torchvision's ImageNet-pretrained models.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    preprocess = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # unsqueeze(0) adds the leading batch dimension.
    return preprocess(image).unsqueeze(0)
# 2D CNN encoder built on a pretrained ResNet-50
class VidaResNet(nn.Module):
    """Frozen ResNet-50 feature extractor.

    The wrapped network is torchvision's ``resnet50`` with its last child
    module (the fully connected classifier) removed, so it emits the pooled
    convolutional features instead of class scores.  Every weight has
    ``requires_grad`` switched off and the module is put in evaluation mode
    when it is constructed.
    """

    def __init__(self, in_size=1, out_size=1):
        """Load the pretrained ResNet-50 and drop its top fc layer.

        Args:
            in_size: Unused; kept so existing call sites keep working.
            out_size: Unused; kept so existing call sites keep working.
        """
        super(VidaResNet, self).__init__()
        # NOTE(review): ``pretrained=True`` is the legacy torchvision spelling
        # (newer releases prefer ``weights=``); it is kept so the module still
        # loads on the torchvision version this project was written against.
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False
        modules = list(resnet.children())[:-1]  # delete the last fc layer.
        self.resnet = nn.Sequential(*modules)
        self.eval()

    def forward(self, input):
        """Return the flattened backbone features for a batch.

        Args:
            input: Batch tensor fed to the truncated ResNet.

        Returns:
            A 2-D tensor with one row per batch element; everything after the
            batch dimension is flattened into a single axis.  No autograd
            graph is recorded.
        """
        with torch.no_grad():
            x = self.resnet(input)  # ResNet
            x = x.view(x.size(0), -1)  # flatten output of conv
        return x

    def predict(self, filename):
        """Extract features for the image stored at ``filename``.

        The image is loaded with ``read_img`` and moved to the device that
        holds the model's weights before it is passed through the backbone.

        Args:
            filename: Image location, forwarded to ``read_img``.

        Returns:
            The raw (unflattened) output of the truncated ResNet for the
            single image, living on the same device as the model.
        """
        stim = read_img(filename)
        # Follow the model's own device rather than "CUDA if available":
        # a CPU-resident model on a CUDA-capable host would otherwise receive
        # a CUDA input and fail with a device-mismatch error.
        param = next(self.parameters(), None)
        if param is not None:
            stim = stim.to(param.device)
        with torch.no_grad():
            out = self.resnet(stim)
        return out
# main for testing
if __name__ == '__main__':
    # Imported here so the script entry point does not depend on ``os`` and
    # ``sys`` happening to leak in through ``from config import *``.
    import os
    import sys

    if len(sys.argv) < 2:
        sys.exit('usage: %s IMAGE_FILE' % sys.argv[0])

    filename = sys.argv[1]

    # Use the GPU only when one is present so the script also runs on
    # CPU-only machines.
    model = VidaResNet()
    if torch.cuda.is_available():
        model = model.cuda()

    x = model.predict(filename)
    x = x.cpu().numpy()
    # Column vector: savetxt then writes one feature value per line.
    x = np.reshape(x, (x.size, 1))

    # Output goes next to the input image, with the extension swapped for .txt.
    name_out = os.path.splitext(filename)[0]
    np.savetxt(name_out + '.txt', x)