-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_checkpoints.py
39 lines (27 loc) · 1.14 KB
/
run_checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Script to enable a subjective evaluation of input data using a predetermined checkpoint.
import os
import argparse
import numpy as np
import torch
import pytorch_lightning as pl
import sys
from train import XUMXManager
parser = argparse.ArgumentParser()
import yaml
# Paths of the training artifacts this script reads and writes.
# All of them live under `epoch_dir`.
epoch_name = 'epoch=7-step=36575.ckpt'
epoch_dir = 'exp_outputs'
serialized_model_path = os.path.join(epoch_dir, 'serialized_model')
train_info_data_path = os.path.join(epoch_dir, 'train_data_info_dict')
epoch_path = os.path.join(epoch_dir, epoch_name)

# NOTE: torch.load unpickles its input, so these paths must only ever point
# at trusted files. map_location remaps every stored tensor onto the CPU.
state_dict = torch.load(epoch_path, map_location=torch.device('cpu'))
model = torch.load(serialized_model_path, map_location=torch.device('cpu'))
train_info = torch.load(train_info_data_path, map_location=torch.device('cpu'))

# Strip the leading 'model.' from checkpoint parameter names. The full
# 'model.' prefix (including the dot) is required, so unrelated keys that
# merely begin with the letters "model" (e.g. 'model_ema.weight') are left
# untouched instead of being truncated.
model_prefix = 'model.'
modified_state_dict = {
    (key[len(model_prefix):] if key.startswith(model_prefix) else key): val
    for key, val in state_dict['state_dict'].items()
}

# `model` is used as a dict here: install the renamed weights, then merge in
# the training-data info. The order matters: entries of `train_info` override
# same-named entries already in `model`.
model['state_dict'] = modified_state_dict
model.update(train_info)

# Drop this entry if it exists (no error when absent).
# NOTE(review): presumably a loss-function buffer that is not needed to run
# the model -- confirm.
model['state_dict'].pop('loss_func.transform.0.window', None)

# Write the combined dict next to the source checkpoint as 'runnable-<ckpt>'.
torch.save(model, os.path.join(epoch_dir, 'runnable-{}'.format(epoch_name)))