# data.py
import os
import re
import string
import unicodedata

from datasets import load_dataset


def normalize_answer(s):
    """Normalize answer. (Directly copied from ORQA codebase)"""
    s = unicodedata.normalize("NFD", s)

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
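

# Illustrative behaviour of normalize_answer (not from the original file; traced
# by hand through lower -> remove_punc -> remove_articles -> white_space_fix):
#   normalize_answer("The Alan Mathison Turing!") -> "alan mathison turing"
#   normalize_answer("An  apple, a day.")         -> "apple day"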


def load(args):
    """Load dataset"""
    if os.path.isdir(args.dataset_name_path):
        raise ValueError("Dataset path currently not supported.")
    if args.dataset_name_path == "natural_questions":
        return load_nq(args)
    elif args.dataset_name_path == "web_questions":
        return load_wq(args)
    elif args.dataset_name_path == "dummy":
        return load_dummy(args)
    else:
        raise ValueError("Invalid dataset name or path")


def load_nq(args):
    """Load NaturalQuestions."""

    def filter_fn(example):
        """Keep examples with at least one short answer spanning no more than args.max_answer_tokens tokens."""
        for short_answer in example['annotations.short_answers']:
            if len(short_answer) != 0:
                for i in range(len(short_answer['text'])):
                    if short_answer['end_token'][i] - short_answer['start_token'][i] <= args.max_answer_tokens:
                        return True
        return False

    def map_fn(example):
        """Unify dataset structures."""
        return {
            "question": example["question.text"],
            # Each annotation contributes a (possibly empty) list of answer strings.
            "answers": [answer["text"] for answer in example["annotations.short_answers"]]
        }

    dataset = load_dataset(args.dataset_name_path, cache_dir=os.path.abspath(args.dataset_cache_dir))
    # Remove unused columns and flatten structure.
    training_dev_dataset = dataset['train'].train_test_split(test_size=args.dev_ratio, shuffle=False)
    training_dataset = training_dev_dataset['train'].remove_columns(['id', 'document']).flatten()
    dev_dataset = training_dev_dataset['test'].remove_columns(['id', 'document']).flatten()
    eval_dataset = dataset['validation'].remove_columns(['id', 'document']).flatten()
    # Perform filtering and mapping.
    filtered_training_dataset = training_dataset.filter(filter_fn).map(map_fn)
    filtered_dev_dataset = dev_dataset.filter(filter_fn).map(map_fn)
    filtered_eval_dataset = eval_dataset.filter(filter_fn).map(map_fn)
    # An example of each dataset should contain the following columns:
    #   example["question"]
    #   example["answers"][num_answers]
    return filtered_training_dataset, filtered_dev_dataset, filtered_eval_dataset


def load_wq(args):
    """Load WebQuestions (WQ)."""
    dataset = load_dataset(args.dataset_name_path, cache_dir=os.path.abspath(args.dataset_cache_dir))
    # Remove unused columns and flatten structure.
    training_dev_dataset = dataset['train'].train_test_split(test_size=args.dev_ratio, shuffle=False)
    training_dataset = training_dev_dataset['train'].remove_columns(['url'])
    dev_dataset = training_dev_dataset['test'].remove_columns(['url'])
    eval_dataset = dataset['test'].remove_columns(['url'])
    # No filtering is needed for WQ.
    filtered_training_dataset = training_dataset
    filtered_dev_dataset = dev_dataset
    filtered_eval_dataset = eval_dataset
    # An example of each dataset should contain the following columns:
    #   example["question"]
    #   example["answers"][num_answers]
    return filtered_training_dataset, filtered_dev_dataset, filtered_eval_dataset


def load_dummy(args):
    """Return a tiny in-memory dataset for quick smoke tests."""
    dataset = [
        {
            "question": "What is the previous name of Meta Platforms, Inc.?",
            "answers": [
                "facebook, inc.",
            ],
        },
        {
            "question": "Who is the pioneer in modern computer science?",
            "answers": [
                "alan mathison turing",
            ],
        },
    ]
    return dataset, dataset, dataset


class DataCollator(object):
    def __init__(self, args, tokenizer):
        self.args = args
        self.tokenizer = tokenizer

    def __call__(self, batch):
        # Only the first example in the batch is used.
        example = batch[0]
        question = example["question"]
        answer_texts = []
        for answer in example["answers"]:
            # Answers are stored either as a single string (WQ, dummy) or as a
            # list of strings per annotation (NQ); flatten both into one list.
            answer_texts += [answer] if isinstance(answer, str) else answer
        # Deduplicate answer strings.
        answer_texts = list(set(answer_texts))
        if len(answer_texts) != 0:
            answer_ids = self.tokenizer(
                answer_texts,
                add_special_tokens=False,
                return_token_type_ids=False,
                return_attention_mask=False,
            ).input_ids
        else:
            # No answers available: fall back to a sentinel id.
            answer_ids = [[-1]]
        return question, answer_texts, answer_ids
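

# Minimal end-to-end sketch (not part of the original file). It exercises the
# "dummy" dataset and DataCollator with a Hugging Face tokenizer; the model name
# "bert-base-uncased" and the argument values below are illustrative assumptions.
if __name__ == "__main__":
    import argparse

    from transformers import AutoTokenizer

    args = argparse.Namespace(
        dataset_name_path="dummy",
        dataset_cache_dir="./cache",
        dev_ratio=0.1,
        max_answer_tokens=5,
    )
    train_set, dev_set, eval_set = load(args)

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # downloads on first use
    collator = DataCollator(args, tokenizer)

    # The collator expects a batch (list) of examples and uses the first one.
    question, answer_texts, answer_ids = collator([train_set[0]])
    print("Question:   ", question)
    print("Answers:    ", answer_texts)
    print("Normalized: ", [normalize_answer(a) for a in answer_texts])
    print("Answer ids: ", answer_ids)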