-
Notifications
You must be signed in to change notification settings - Fork 2
/
load_model.py
54 lines (50 loc) · 2.02 KB
/
load_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from model import SimpleNet
from torch import optim
import torch as meg
from pathlib import Path
from collections import OrderedDict
def _remap_state_dict(model, states):
    """Map a foreign state dict's tensors onto *model*'s parameter names.

    Assumes *states* holds the same parameters in the same order as
    ``model.state_dict()`` but under different key names — TODO confirm
    against the checkpoint producer.
    """
    return OrderedDict(zip(model.state_dict().keys(), states.values()))


def settings(base_lr=1e-5, pretrained=None, cuda=False):
    """Build a SimpleNet, its Adam optimizer, and a CyclicLR scheduler.

    Args:
        base_lr: Initial learning rate (also CyclicLR's lower bound).
        pretrained: Optional checkpoint filename. ``'torch_pretrained.ckp'``
            loads a raw state dict whose keys are remapped onto this model;
            ``'checkpoint.pth'`` restores model + optimizer state and the
            saved learning rate. Any other value skips loading.
        cuda: Move the model to GPU before (possibly) loading weights.

    Returns:
        Tuple of (model, optimizer, lr_scheduler).
    """
    model = SimpleNet()
    optimizer = optim.Adam(model.parameters(), lr=base_lr, weight_decay=0.0001)
    if cuda:
        model = model.cuda()

    # Without map_location, a GPU-saved checkpoint fails to load on a
    # CPU-only machine; None lets torch pick the default when cuda=True.
    map_location = None if cuda else meg.device('cpu')

    if pretrained == 'torch_pretrained.ckp':
        with Path(pretrained).open("rb") as f:
            states = meg.load(f, map_location=map_location)
        model.load_state_dict(_remap_state_dict(model, states))
        model.eval()
    elif pretrained == 'checkpoint.pth':
        # Fix: original called meg.load(pretrained) with no map_location,
        # which crashes on CPU-only hosts for GPU-saved checkpoints.
        model_ckpt = meg.load(pretrained, map_location=map_location)
        model.load_state_dict(model_ckpt['model'])
        optimizer.load_state_dict(model_ckpt['optimizer'])
        base_lr = model_ckpt['lr']

    lr_scheduler = optim.lr_scheduler.CyclicLR(
        optimizer, base_lr=base_lr, max_lr=1e-3, cycle_momentum=False)
    return model, optimizer, lr_scheduler
if __name__ == "__main__":
    # Smoke test: load the pretrained checkpoint on CPU and remap its
    # state-dict values onto SimpleNet's own parameter names, in order.
    net = SimpleNet()
    ckpt_path = Path('torch_pretrained.ckp')
    with ckpt_path.open("rb") as f:
        raw_states = meg.load(f, map_location=meg.device('cpu'))
    # Pair the model's keys with the checkpoint's values positionally.
    remapped = OrderedDict(
        zip(net.state_dict().keys(), raw_states.values())
    )
    net.load_state_dict(remapped)
    net.eval()