"""
build_dataset2_baseline.py

NOTE: install the following spaCy model before running this script:
    python -m spacy download en_core_web_trf
"""
import random

import spacy
from tqdm import tqdm
from transformers import BertTokenizer

from ner_processing import custom_anonymize_text
from utils import save_row_to_jsonl_file, empty_json_file, load_jsonl_file

# Load the spaCy transformer model (used below for NER-based anonymization)
nlp_trf = spacy.load("en_core_web_trf")

# Path to the output JSONL file
output_file = "/content/shared_data/dataset_2_2_pair_sentences.jsonl"
# Initialize a JSONL file for the dataset
empty_json_file(output_file)
# Load the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
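# bert-base-uncased lowercases its input; in this script the tokenizer is only
# used to measure sequence length against BERT's limit.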
# Load all passage records from the JSON file
dataset2_raw = load_jsonl_file("/content/shared_data/dataset2_raw_dec_11.jsonl")
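# Each record is expected to look like this (inferred from the field accesses below):
#   {"id": ..., "metadata": {"annotator": ...},
#    "text": [{"sentence": ..., "role": "beginning" | "inside" | "outside"}, ...]}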
WITH_ANONYMIZATION = True
SEED = 42
# Initialize unique ID counter for datapoints
datapoint_id = 0
# Initialize counters
exceeded_token_limit = 0
continue_class_counter = 0
not_continue_class_counter = 0
# Initialize dataset
dataset = []
# Initialize ignored annotators
ignore_list = ["IE-Reyes"]
# Set seed for reproducibility
random.seed(SEED)
for passage in tqdm(dataset2_raw, desc=f"Processing {len(dataset2_raw)} passages"):
if passage["metadata"]["annotator"] in ignore_list:
continue
# Get sentences data
sentences = passage["text"]
for idx, sentence in enumerate(sentences):
# Proceed only if there's a next sentence
if idx < len(sentences) - 1:
# Modify role to 'inside' if it's 'beginning'
if sentences[idx + 1]["role"] == "beginning":
sentences[idx + 1]["role"] = "inside"
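            # After this remap, next_sentence is never "beginning", so the label rules below only see "inside"/"outside".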
next_sentence = sentences[idx + 1]
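            # BERT-base accepts at most 512 tokens per sequence; over-length pairs are skipped rather than truncated.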
            # Form the pair with BERT special tokens ([CLS]/[SEP]) and check the length limit
pair = "[CLS] " + sentence["sentence"] + " [SEP] " + next_sentence["sentence"] + " [SEP]"
if len(tokenizer.tokenize(pair)) > 512:
exceeded_token_limit += 1
continue
if WITH_ANONYMIZATION:
# Anonymize text
more_entities = ["COVID-19", "COVID", "Army", "WeCanDoThis.HHS.gov", "HIV"]
pair = custom_anonymize_text(pair, nlp_trf,
["PERSON", "NORP", "FAC", "ORG", "GPE", "LOC", "PRODUCT", "EVENT", "LAW", "DATE",
"TIME", "MONEY", "QUANTITY"])
for entity in more_entities:
pair = pair.replace(entity, "[ENTITY]")
# Assign labels based on roles
label = None
if sentence["role"] == "inside" and next_sentence["role"] == "inside":
label = "continue"
continue_class_counter += 1
elif sentence["role"] == "outside" and next_sentence["role"] == "inside":
label = "not_continue"
not_continue_class_counter += 1
elif sentence["role"] == "inside" and next_sentence["role"] == "outside":
label = "not_continue"
not_continue_class_counter += 1
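            # Any other role combination (e.g. both sentences "outside") leaves label as None and the pair is skipped.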
# Only create and save datapoint if label is assigned
if label:
datapoint_id += 1
datapoint = {
"id": datapoint_id,
"passage_id": passage["id"],
"text": pair,
"label": label,
"annotator": passage["metadata"]["annotator"],
}
# print(f"Processing datapoint ({datapoint_id}): {datapoint['text']}")
dataset.append(datapoint)
# Prune the "continue" class to balance the label distribution
dataset_continue = []
dataset_not_continue = []
for datapoint in tqdm(dataset, desc="Pruning 'continue' class"):
if datapoint["label"] == "continue":
dataset_continue.append(datapoint)
else:
dataset_not_continue.append(datapoint)
# Shuffle the "continue" datapoints so that pruning removes a random sample
random.shuffle(dataset_continue)
# Keep 87% of the "continue" datapoints (i.e. prune 13%) to balance the classes
dataset_continue = dataset_continue[:int(len(dataset_continue) * 0.87)]
dataset = dataset_continue + dataset_not_continue
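# Note: the combined dataset is not reshuffled, so the remaining "continue" pairs precede all "not_continue" pairs in the output.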
# Save the dataset to the output JSONL file
for datapoint in tqdm(dataset, desc="Saving dataset"):
save_row_to_jsonl_file(datapoint, output_file)
print("\nClass distribution:")
print(f"• Continue: {len(dataset_continue)} (before prune {continue_class_counter})")
print(f"• Not continue: {not_continue_class_counter}")
print(f"\nExceeded token limit: {exceeded_token_limit}")
print(f"Ignored annotators: {ignore_list}")