-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpreprocess.py
91 lines (83 loc) · 4.07 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
from multiprocessing import cpu_count
from preprocess_utils.convert_csqa import convert_to_entailment
from preprocess_utils.convert_obqa import convert_to_obqa_statement
from preprocess_utils.conceptnet import extract_english, construct_graph
from preprocess_utils.grounding import create_matcher_patterns, ground
from preprocess_utils.graph import generate_adj_data_from_grounded_concepts__use_LM
# Raw input files, keyed first by resource name ('csqa', 'obqa', 'cpnet')
# and then by split ('train' / 'dev' / 'test') or, for 'cpnet', by file kind.
input_paths = {
    'csqa': {
        'train': './data/csqa/train_rand_split.jsonl',
        'dev': './data/csqa/dev_rand_split.jsonl',
        'test': './data/csqa/test_rand_split_no_answers.jsonl',
    },
    'obqa': {
        'train': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/train_complete.jsonl',
        'dev': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/dev_complete.jsonl',
        'test': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/test_complete.jsonl',
    },
    'cpnet': {
        'csv': './data/cpnet/conceptnet-assertions-5.6.0.csv',
    },
}
# Destination files for generated artifacts.
# 'cpnet' maps an artifact name directly to a path; 'csqa' and 'obqa' map an
# artifact kind ('statement' / 'grounded' / 'graph') to a per-split dict.
output_paths = {
    'cpnet': {
        'csv': './data/cpnet/conceptnet.en.csv',
        'vocab': './data/cpnet/concept.txt',
        'patterns': './data/cpnet/matcher_patterns.json',
        'unpruned-graph': './data/cpnet/conceptnet.en.unpruned.graph',
        'pruned-graph': './data/cpnet/conceptnet.en.pruned.graph',
    },
    'csqa': {
        'statement': {
            'train': './data/csqa/statement/train.statement.jsonl',
            'dev': './data/csqa/statement/dev.statement.jsonl',
            'test': './data/csqa/statement/test.statement.jsonl',
        },
        'grounded': {
            'train': './data/csqa/grounded/train.grounded.jsonl',
            'dev': './data/csqa/grounded/dev.grounded.jsonl',
            'test': './data/csqa/grounded/test.grounded.jsonl',
        },
        'graph': {
            'adj-train': './data/csqa/graph/train.graph.adj.pk',
            'adj-dev': './data/csqa/graph/dev.graph.adj.pk',
            'adj-test': './data/csqa/graph/test.graph.adj.pk',
        },
    },
    'obqa': {
        'statement': {
            'train': './data/obqa/statement/train.statement.jsonl',
            'dev': './data/obqa/statement/dev.statement.jsonl',
            'test': './data/obqa/statement/test.statement.jsonl',
            # Extra per-split outputs under the fairseq/official directory.
            'train-fairseq': './data/obqa/fairseq/official/train.jsonl',
            'dev-fairseq': './data/obqa/fairseq/official/valid.jsonl',
            'test-fairseq': './data/obqa/fairseq/official/test.jsonl',
        },
        'grounded': {
            'train': './data/obqa/grounded/train.grounded.jsonl',
            'dev': './data/obqa/grounded/dev.grounded.jsonl',
            'test': './data/obqa/grounded/test.grounded.jsonl',
        },
        'graph': {
            'adj-train': './data/obqa/graph/train.graph.adj.pk',
            'adj-dev': './data/obqa/graph/dev.graph.adj.pk',
            'adj-test': './data/obqa/graph/test.graph.adj.pk',
        },
    },
}
def main():
    """Parse the command line and run the OBQA "re_grounded" graph step.

    Command-line options:
        --run         one or more of 'common', 'csqa', 'obqa' (all three by
                      default). In this function the value is only echoed in
                      the final message; it does not select what is executed.
        -p, --nprocs  number of processes, forwarded to the helper as its last
                      argument (defaults to ``cpu_count()``).
        --debug       not supported.

    Raises:
        NotImplementedError: if ``--debug`` is given.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--run',
        default=['common', 'csqa', 'obqa'],
        choices=['common', 'csqa', 'obqa'],
        nargs='+',
    )
    cli.add_argument(
        '-p', '--nprocs',
        type=int,
        default=cpu_count(),
        help='number of processes to use',
    )
    cli.add_argument('--debug', action='store_true', help='enable debug mode')
    args = cli.parse_args()

    if args.debug:
        raise NotImplementedError()

    # ConceptNet files shared by every call below.
    cpnet_graph_path = 'data/cpnet/conceptnet.en.pruned.graph'
    cpnet_vocab_path = 'data/cpnet/concept.txt'

    # (first argument, fourth argument) for each helper call, in the order the
    # calls are made: test, train, dev.
    # NOTE(review): presumably (grounded input, adjacency output) -- confirm
    # against preprocess_utils.graph.
    split_jobs = (
        ('./data/obqa/re_grounded/test.re_grounded.jsonl', './data/obqa/re_graph/test.re_graph.adj.pk'),
        ('./data/obqa/re_grounded/train.re_grounded.jsonl', './data/obqa/re_graph/train.re_graph.adj.pk'),
        ('./data/obqa/re_grounded/dev.re_grounded.jsonl', './data/obqa/re_graph/dev.re_graph.adj.pk'),
    )
    for grounded_path, adj_path in split_jobs:
        generate_adj_data_from_grounded_concepts__use_LM(
            grounded_path,
            cpnet_graph_path,
            cpnet_vocab_path,
            adj_path,
            args.nprocs,
        )

    print('Successfully run {}'.format(' '.join(args.run)))
# Script entry point: run the preprocessing only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()