-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathinference.py
95 lines (74 loc) · 2.33 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python
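"""
inference.py

Run a trained ULMFiT TextClassifier over sequences loaded from a .npy file
and write its predictions, in the original input order, to a tab-separated
text file.
"""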
import os
import sys
import torch
import argparse
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from basenet.helpers import set_seeds, to_numpy
from basenet.text.data import RaggedDataset, text_collate_fn
from ulmfit import TextClassifier
# --
# CLI
def parse_args(argv=None):
    """Parse command-line options for the inference script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with attributes ``lm_weights_path``, ``X``,
        ``outpath`` and ``seed``.

    Raises:
        SystemExit: if a required option is missing or a value fails type
            conversion (standard argparse behavior).
    """
    parser = argparse.ArgumentParser(
        description='Run a trained TextClassifier over saved sequences.'
    )
    # Both paths are loaded unconditionally by the script, so fail fast with
    # a clear usage message instead of crashing later on a None path.
    parser.add_argument('--lm-weights-path', type=str, required=True,
                        help='path to the state_dict loaded into the classifier')
    parser.add_argument('--X', type=str, required=True,
                        help='path to a .npy array of input sequences')
    parser.add_argument('--outpath', type=str, default='preds',
                        help='output file for tab-separated predictions')
    parser.add_argument('--seed', type=int, default=123,
                        help='random seed')
    return parser.parse_args(argv)
# --
# Run
if __name__ == "__main__":
    # --
    # Params

    # Architecture hyperparameters. They must match the trained weights,
    # because load_state_dict below is called with strict=True.
    bptt, emb_sz, n_hid, n_layers, batch_size = 70, 400, 1150, 3, 48
    max_seq = 20 * bptt  # 1400, as before; expressed via bptt for consistency
    pad_token = 1

    args = parse_args()
    set_seeds(args.seed)

    # --
    # IO

    # Ragged sequences are stored as an object-dtype array, which np.load
    # refuses to read unless allow_pickle=True (NumPy >= 1.16.3).
    # SECURITY: unpickling can execute arbitrary code -- only pass trusted
    # files via --X.
    X = np.load(args.X, allow_pickle=True)

    # Sort by length, longest to shortest, so batches contain sequences of
    # similar length (less padding). The permutation `o` is inverted after
    # prediction to restore the caller's order.
    o = np.argsort([len(xx) for xx in X])[::-1]
    X = X[o]

    dataloaders = {
        "inference": DataLoader(
            # Placeholder labels (all -1): no labels exist at inference time;
            # presumably RaggedDataset requires a y argument -- TODO confirm.
            dataset=RaggedDataset(X, y=torch.zeros(len(X)).long() - 1),
            batch_size=batch_size,
            collate_fn=text_collate_fn,
            shuffle=False,  # must preserve the length-sorted order
            num_workers=1,
        )
    }

    # --
    # Define model

    # map_location='cpu' lets a checkpoint saved on any GPU be read here;
    # load_state_dict copies the tensors onto the CUDA model below.
    lm_weights = torch.load(args.lm_weights_path, map_location='cpu')
    # Number of tokens = rows of the encoder weight matrix.
    n_tok = lm_weights['encoder.encoder.weight'].shape[0]
    # NOTE(review): assumes the last state_dict entry belongs to the final
    # output layer, so its first dimension is the number of classes.
    n_class = lm_weights[list(lm_weights.keys())[-1]].shape[0]

    classifier = TextClassifier(
        bptt=bptt,
        max_seq=max_seq,
        n_class=n_class,
        n_tok=n_tok,
        emb_sz=emb_sz,
        n_hid=n_hid,
        n_layers=n_layers,
        pad_token=pad_token,
        # presumably 3 * emb_sz reflects concat pooling in the head -- confirm
        head_layers=[emb_sz * 3, 50, n_class],
        head_drops=[0.0, 0.0],
        predict_only=True,
    ).to('cuda')
    classifier.verbose = True
    classifier.load_state_dict(lm_weights, strict=True)
    _ = classifier.eval()

    preds, _ = classifier.predict(dataloaders, mode='inference')

    # np.argsort(o) is the inverse permutation of o: restore input order.
    preds = to_numpy(preds)[np.argsort(o)]
    np.savetxt(args.outpath, preds, fmt='%.10f', delimiter='\t')