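"""Driver script for RedThread: active search with an oracle on a sampled dataset.

Loads the feature matrix, labels, and per-modality feature names, then, for
each randomly chosen seed node, queries the oracle for labels until the
budget is exhausted and reports precision, recall, and F1 averaged over the
seeds.
"""
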
import pickle as pkl
from argparse import ArgumentParser

import numpy as np
import networkx as nx  # used by the commented-out graph-serialization code below

from redthread import RedThread

def get_args():
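    """Parse command-line arguments: data paths, seed count, budget, and graph-build flag."""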
    parser = ArgumentParser()
    parser.add_argument("-data", "--data_file", default="./data/sample_data/sampled_data_features.pkl",
                        help="path to the data pickle file containing the feature matrix")
    parser.add_argument("-label", "--label_file", default="./data/sample_data/sampled_data_labels.pkl",
                        help="path to the labels of the data points as a pickle file")
    parser.add_argument("-modality", "--feature_name_file", default="./data/sample_data/sampled_data_feature_names.pkl",
                        help="path to the feature names of the data as a pickle file")
    parser.add_argument("--num_seeds", default=1, type=int,
                        help="number of seed data points to average the final precision and recall over")
    parser.add_argument("--data_folder", default="./data/sample_data/",
                        help="path to the folder containing the different feature files")
    parser.add_argument("--budget", default=10, type=int,
                        help="number of times the model can query the user")
    # argparse's type=bool treats any non-empty string (even "False") as True,
    # so a store_true flag is used instead
    parser.add_argument("--build_graph", action="store_true",
                        help="pass this flag to build the data graph from scratch; omit it to use the pre-built graph")
    args = parser.parse_args()
    return args

def iterative_labelling(seed, data, budget, rt, rt_graph, feature_map, evidence_nbrs, ad_nbrs):
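    """Pick nodes with RedThread and query the oracle until the budget is spent.

    Only previously unlabelled picks are sent to the oracle and counted
    against the budget. Returns precision (over the budget) and recall
    (over all relevant, i.e. positively labelled, nodes).
    """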
    query_counter = 0  # initializing the query counter "b" in the algorithm
    # label_hash = rt.label_hash
    # rt_graph = rt.get_graph()
    picked_nodes = []
    precision = 0.
    recall = 0.
    num_relevant = rt.num_positive_labels()
    while query_counter < budget:
        print("Remaining number of queries: " + str(budget - query_counter))
        picked_node = rt.infer_redthread(rt_graph, ad_nbrs, evidence_nbrs, feature_map)  # pick a data point
        # if rt.near_duplicate(picked_node):  # check if the picked node is a near duplicate of the positively labelled nodes so far
        #     continue
        if picked_node not in rt.label_hash:
            picked_node_label = rt.oracle(picked_node)  # querying the user for the label of the picked node
            rt.update_label_hash(picked_node, picked_node_label)  # update the label hash based on the oracle output
            query_counter += 1
            if picked_node_label == 1:
                precision += 1
                recall += 1
        else:
            picked_node_label = rt.label_hash[picked_node]
        print("updating redthread")
        rt.update_redthread(rt_graph, picked_node, picked_node_label, feature_map, ad_nbrs, evidence_nbrs)
        picked_nodes.append(picked_node)
    # print(label_hash)
    precision /= budget
    recall /= num_relevant
    return precision, recall

# def word_to_id(words):
# # assigns an id to each of the feature words
# word_id_map = {}
# counter = 1
# for word in set(words):
# word_id_map[word] = -counter
# counter += 1
# return word_id_map
# def get_word_id(words, word_id_map):
# # stores the id of the word in the feature_map
# word_ids = []
# for word in set(words):
# word_ids.append(word_id_map[word])
# return word_ids
def extract_info(args):
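    """Load the feature matrix, labels, and feature names from pickle files.

    feature_map is keyed by modality: description and title unigrams and
    bigrams, plus location unigrams.
    """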
    data_file = args.data_file
    label_file = args.label_file
    all_feature_file = args.feature_name_file
    data = pkl.load(open(data_file, "rb"))
    labels = pkl.load(open(label_file, "rb"))
    feature_names = pkl.load(open(all_feature_file, "rb"))
    feature_map = {}
    feature_map["desc_uni"] = pkl.load(open(args.data_folder + "desc_feature_names_uni.pkl", "rb"))
    feature_map["desc_bi"] = pkl.load(open(args.data_folder + "desc_feature_names_bi.pkl", "rb"))
    feature_map["title_uni"] = pkl.load(open(args.data_folder + "title_feature_names_uni.pkl", "rb"))
    feature_map["title_bi"] = pkl.load(open(args.data_folder + "title_feature_names_bi.pkl", "rb"))
    feature_map["loc_uni"] = pkl.load(open(args.data_folder + "loc_feature_names.pkl", "rb"))
    return data, labels, feature_names, feature_map

if __name__ == "__main__":
    # get command line arguments
    args = get_args()
    # get the data, labels, and feature names
    data, labels, feature_names, feature_map = extract_info(args)
    # choosing random seed nodes for now
    total_prec = 0.
    total_rec = 0.
    seeds = np.random.choice(len(data), size=args.num_seeds, replace=False)
    for seed in seeds:
        # create a RedThread object
        rt = RedThread(labels, seed, feature_names, feature_map, args.build_graph)
        rt_graph, evidence_nbrs, ad_nbrs = rt.build_graph(data, args.build_graph, feature_names)
        rt.initialize_q(rt_graph, ad_nbrs, evidence_nbrs, seed, feature_map)
        rt.initialize_shell(feature_map, evidence_nbrs)
        # rt_graph = rt.graph
        # nx.write_gpickle(rt_graph, "models/redthread_graph.gpkl")
        # pkl.dump(rt.neighbors, open("models/redthread_graph_node_neighbors.pkl", "wb"))
        precision, recall = iterative_labelling(seed, data, args.budget, rt, rt_graph, feature_map, evidence_nbrs, ad_nbrs)
        # guard against division by zero when precision and recall are both 0
        f1_score = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.
        print("Precision = " + str(precision))
        print("Recall = " + str(recall))
        print("F1 score = " + str(f1_score))
        total_prec += precision
        total_rec += recall
    total_prec /= len(seeds)
    total_rec /= len(seeds)
    total_f1 = 2 * total_prec * total_rec / (total_prec + total_rec) if total_prec + total_rec > 0 else 0.
    print("Total precision = " + str(total_prec))
    print("Total recall = " + str(total_rec))
    print("Total F1 score = " + str(total_f1))
    print("success")

# Sample per-seed output from previous runs:
'''
Precision = 0.6
Recall = 0.024
F1 score = 0.04615384615384615
Precision = 0.65
Recall = 0.026
F1 score = 0.04999999999999999
Precision = 0.7
Recall = 0.028
F1 score = 0.05384615384615385
'''
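
# Example invocation (all paths default to the repo's sample data; add
# --build_graph to rebuild the data graph from scratch):
#   python redthread_run.py --num_seeds 1 --budget 10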