-
Notifications
You must be signed in to change notification settings - Fork 35
/
mrr.py
51 lines (37 loc) · 1.3 KB
/
mrr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# -*- coding: utf-8 -*-
from __future__ import division, print_function
from prme import mrr
import pandas as pd
import plac
import numpy as np
def _read_trace(trace_fpath, skip, hyper2id, obj2id):
    """Read the events of a trace file that the model can score.

    Each line must hold four tab-separated fields: a time delta, a hyper
    node, a source object and a destination object. The first `skip` lines
    are ignored, as are lines whose hyper node or objects are missing from
    the id mappings.

    Returns a pair `(dts, HSDs)`: a list of float time deltas and a list of
    `[hyper_id, source_id, destination_id]` triples, in file order.

    Raises ValueError when a line past the skipped prefix does not have
    exactly four fields or its first field is not a number.
    """
    dts = []
    HSDs = []
    with open(trace_fpath) as trace_file:
        for i, line in enumerate(trace_file):
            if i < skip:
                continue

            dt, h, s, d = line.strip().split('\t')
            # Events mentioning ids the model has never seen cannot be
            # mapped to rows of the model matrices, so they are dropped.
            if h in hyper2id and s in obj2id and d in obj2id:
                dts.append(float(dt))
                HSDs.append([hyper2id[h], obj2id[s], obj2id[d]])
    return dts, HSDs


def main(model, out_fpath):
    """Score the held-out part of a trace with a stored model.

    Arguments
    ---------
    model: path of the HDF5 store holding the model. The keys read are
        `from_`, `to`, `trace_fpath`, `XP_hk`, `XP_ok`, `XG_ok`, `alpha`,
        `tau`, `hyper2id` and `obj2id`.
    out_fpath: path of the text file that receives the values returned by
        `mrr.compute` (presumably one reciprocal rank per event; confirm
        against `prme.mrr`), written with `numpy.savetxt`.

    Raises ValueError if the stored `from_` is not 0.
    """
    # The store is only read, so open it read-only: a mistyped path then
    # fails immediately instead of creating an empty file.
    store = pd.HDFStore(model, mode='r')
    try:
        from_ = store['from_'][0][0]
        to = store['to'][0][0]
        trace_fpath = store['trace_fpath'][0][0]

        XP_hk = store['XP_hk'].values
        XP_ok = store['XP_ok'].values
        XG_ok = store['XG_ok'].values

        alpha = store['alpha'].values[0][0]
        tau = store['tau'].values[0][0]

        hyper2id = dict(store['hyper2id'].values)
        obj2id = dict(store['obj2id'].values)
    finally:
        # Everything needed is in memory now; release the file handle even
        # when a key is missing.
        store.close()

    # Explicit check rather than `assert`, which disappears under `python -O`.
    if from_ != 0:
        raise ValueError('expected from_ == 0, got %r' % (from_,))

    # Lines before index `to` are skipped; presumably they are the prefix
    # the model was fit on (verify against the training script).
    dts, HSDs = _read_trace(trace_fpath, to, hyper2id, obj2id)

    dts = np.array(dts, order='C', dtype='d')
    # reshape keeps the (n, 3) layout when no event survived the filter;
    # for n > 0 it is a no-op.
    HSDs = np.array(HSDs, order='C', dtype='i4').reshape(-1, 3)

    rrs = mrr.compute(dts, HSDs, XP_hk, XP_ok, XG_ok, alpha, tau)
    np.savetxt(out_fpath, rrs)
# Run the command-line interface only when executed as a script, so that
# importing this module has no side effects. plac builds the argument
# parser from the signature of `main`.
if __name__ == '__main__':
    plac.call(main)