
Commit d228b50

evaluation script and prediction for BERT models
1 parent 3c58176 commit d228b50

File tree: README.md · evaluation/eval · evaluation/predict_bert · requirements.txt

4 files changed: +371 additions, 0 deletions


README.md (+8)
@@ -37,6 +37,14 @@ TODO: TO_COME
 The source code of training and inference of the models presented in the paper
 are included in the folder `models/` (`models/BERT` and `models/Mixtral`).
 
+## Installation
+
+You need `python3` and `pip3`.
+
+```bash
+pip3 install -r requirements.txt
+```
+
 ## Citation
 
 If you use this corpus or the source code of this repository, please cite the
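Once the requirements are installed, the two scripts added in this commit can be chained as in the sketch below. The dataset folder, checkpoint path, and output directory are placeholders, not paths shipped with the repository; the output location follows the directory layout built by `predict_bert`.

```bash
# Hypothetical paths: adjust the dataset folder, checkpoint and output directory.
python3 evaluation/predict_bert \
    --reference_dataset path/to/dataset \
    --subset validation \
    --model_path path/to/finetuned-bert \
    --output_dir outputs

# predictions.json is written to <output_dir>/<model_path with '/' -> '__'>/<subset>/
# and evaluation/eval saves results.json next to the prediction file.
python3 evaluation/eval \
    --reference_dataset path/to/dataset \
    --subset validation \
    --prediction outputs/path__to__finetuned-bert/validation/predictions.json
```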

evaluation/eval (+200)
@@ -0,0 +1,200 @@
#!/usr/bin/env python3

import argparse

parser = argparse.ArgumentParser(
    description='Tool used to evaluate the performance of model '
                'predictions against the dataset annotations.'
)
parser.add_argument(
    '--reference_dataset', required=True, type=str,
    help='Path to the original reference dataset folder.'
)
parser.add_argument(
    '--prediction', required=True, type=argparse.FileType('r'),
    help='Predictions from the models in JSON format.'
)
parser.add_argument(
    '--subset', required=True, type=str,
    choices=['validation', 'test'],
    help='Subset to evaluate.'
)

args = parser.parse_args()

################################################################################

import os
import json
import datetime
import collections

import numpy as np
import pandas as pd

################################################################################
# PREPARING THE DATASET #
################################################################################

reference_file = os.path.join(args.reference_dataset, f"{args.subset}.json")
if not os.path.exists(reference_file):
    raise FileNotFoundError(
        f"The dataset reference for {args.subset=} cannot be found at "
        f"{reference_file}."
    )

reference = pd.read_json(reference_file, orient='index')
prediction = pd.read_json(args.prediction, orient='index')

reference_dialogues = set(reference.index)
prediction_dialogues = set(prediction.index)

allowed_classes = set(reference.iloc[0].classes.keys())

if reference_dialogues != prediction_dialogues:
    raise ValueError(
        "Reference dialogues are not the same as prediction dialogues.\n"
        f" - Total reference dialogues: {len(reference_dialogues)}\n"
        f" - Total prediction dialogues: {len(prediction_dialogues)}\n"
        f" - Dialogues in common: {len(reference_dialogues & prediction_dialogues)}\n"
        "Make sure you are using the correct --subset [validation/test] "
        "and that you returned the correct dialogue ids in your prediction."
    )

################################################################################
# EVALUATION METRICS #
################################################################################

def compute_confusion_matrix(
    ref,
    pred,
    classes=allowed_classes,
    return_num_processed=False
):
    # One {tp, tn, fp, fn} counter per class; annotations scored 0.5
    # (partial agreement) contribute half a count to each of the two
    # possible cells.
    mat = collections.defaultdict(lambda: {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0})
    num_processed_samples = 0

    for index, row in ref.iterrows():
        num_processed_samples += 1
        for c in classes:
            predicted = pred.loc[index, f"class__{c}"]
            score = row.classes[c]
            if predicted:
                if score == 0:
                    mat[c]['fp'] += 1
                elif score == 1:
                    mat[c]['tp'] += 1
                elif score == 0.5:
                    mat[c]['tp'] += score
                    mat[c]['fp'] += score
                else:
                    raise ValueError(
                        f"Unexpected annotation score {score!r} for class {c!r}."
                    )
            else:
                if score == 0:
                    mat[c]['tn'] += 1
                elif score == 1:
                    mat[c]['fn'] += 1
                elif score == 0.5:
                    mat[c]['tn'] += score
                    mat[c]['fn'] += score
                else:
                    raise ValueError(
                        f"Unexpected annotation score {score!r} for class {c!r}."
                    )

    mat = pd.DataFrame.from_dict(mat, orient='index')
    if return_num_processed:
        return mat, num_processed_samples
    else:
        return mat

def average_metric(mat, met, average, metric):
    if average is None:
        return met
    elif average == 'macro':
        return met.mean()
    elif average == 'micro':
        return metric(mat.sum(), average=None)
    elif average == 'all':
        return {
            avg: average_metric(mat, met, avg, metric)
            for avg in [None, 'micro', 'macro']
        }
    else:
        raise ValueError(f"Unknown average {average!r}.")

def precision(mat, average=None):
    _precision = mat['tp'] / (mat['tp'] + mat['fp'])

    # precision defaults to 1.0 for classes with no predicted examples
    if isinstance(_precision, pd.Series):
        _precision = _precision.fillna(value=1.0)
    else:
        _precision = np.nan_to_num(_precision, nan=1.0)

    return average_metric(
        mat,
        _precision,
        average=average, metric=precision
    )

def recall(mat, average=None):
    _recall = mat['tp'] / (mat['tp'] + mat['fn'])
    assert not np.isnan(_recall).any(), \
        "Recall cannot have a NaN. That would mean a label has no " \
        "occurrence in the evaluated subset."
    return average_metric(
        mat,
        _recall,
        average=average, metric=recall
    )

def f1(mat, average=None):
    p, r = precision(mat), recall(mat)
    _f1 = 2 * (p * r) / (p + r)

    # f1 defaults to 0.0 when both precision and recall are zero
    if isinstance(_f1, pd.Series):
        _f1 = _f1.fillna(value=0.0)
    else:
        _f1 = np.nan_to_num(_f1, nan=0.0)

    return average_metric(
        mat,
        _f1,
        average=average, metric=f1
    )

def all_metrics(mat, average=None):
    return {
        k: globals()[k](mat, average=average)
        for k in ['f1', 'precision', 'recall']
    }

confusion = compute_confusion_matrix(
    ref=reference, pred=prediction,
)

micro = all_metrics(confusion, average='micro')
macro = all_metrics(confusion, average='macro')
per_class = all_metrics(confusion, average=None)

print(f"Micro: {micro}")
print(f"Macro: {macro}")

# results are stored next to the prediction file
output_file = os.path.join(
    os.path.dirname(args.prediction.name), 'results.json'
)

with open(output_file, 'w') as f:
    json.dump({
        'now': datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S'),
        'num_reference_samples': len(reference),
        'num_predicted_samples': len(prediction),
        'micro': micro,
        'macro': macro,
        'per_class': {
            metric: per_class[metric].to_dict()
            for metric in per_class
        },
        'args': str(args)
    }, f, indent='\t', ensure_ascii=False)

print(f"Outputs saved to {output_file=}.")
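The `--prediction` file is expected to be a JSON object keyed by dialogue id, with one boolean `class__<label>` field per annotated class (this is the format written by `evaluation/predict_bert` below). A minimal hand-written sketch, with made-up dialogue ids and label names; the real class names come from the reference dataset:

```python
import json

# Made-up dialogue ids, text and class names, for illustration only.
predictions = {
    "dialogue_0001": {
        "text": "transcribed call text ...",
        "class__label_a": True,
        "class__label_b": False,
    },
    "dialogue_0002": {
        "text": "another transcribed call ...",
        "class__label_a": False,
        "class__label_b": False,
    },
}

with open("predictions.json", "w") as f:
    json.dump(predictions, f, ensure_ascii=False, indent="\t")
```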

evaluation/predict_bert (+159)
@@ -0,0 +1,159 @@
#!/usr/bin/env python3

import argparse

parser = argparse.ArgumentParser(
    description='Tool used to generate the predictions of a BERT model.'
)
parser.add_argument(
    '--reference_dataset', required=True, type=str,
    help='Path to the original reference dataset folder.'
)
parser.add_argument(
    '--subset', required=True, type=str,
    choices=['validation', 'test'],
    help='Subset to predict.'
)
parser.add_argument(
    '--model_path', required=True, type=str,
    help='Path of the model to use for prediction.'
)
parser.add_argument(
    '--output_dir', required=True, type=str,
    help='Output directory to save the results in.'
)
parser.add_argument(
    '--lowercase_text', action='store_true',
    help='Apply lowercase to all input text. Should match what was used '
         'when training / finetuning.'
)
parser.add_argument(
    '--device', default='cpu',
    help='Device to predict on.'
)
args = parser.parse_args()
print(args)

################################################################################

import os
import json
import torch
import datasets
import transformers

import numpy as np
import pandas as pd

################################################################################
# LOADING THE MODEL, CONFIG AND TOKENIZER #
################################################################################

# the tokenizer is resolved from the name stored in the checkpoint config
with open(f'{args.model_path}/config.json') as f:
    tokenizer_name = json.load(f)['_name_or_path']

# predictions go to <output_dir>/<model_path with '/' -> '__'>/<subset>/
output_dir = os.path.join(
    args.output_dir,
    args.model_path.replace('/', '__'),
    args.subset,
)
os.makedirs(output_dir, exist_ok=True)

# preparing the model
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    args.model_path,
).to(args.device)

# FlauBERT checkpoints: mean-pool the sequence summary and cap inputs at 256 tokens
if 'flaubert' in model.config.architectures[0].lower():
    model.sequence_summary.summary_type = 'mean'
    model.config.max_length = 256

tokenizer = transformers.AutoTokenizer.from_pretrained(
    tokenizer_name
)

# the Trainer is only used here to run batched prediction on the chosen subset
trainer = transformers.Trainer(
    model,
    args=torch.load(f'{args.model_path}/training_args.bin'),
    data_collator=transformers.DefaultDataCollator(),
)

allowed_classes = set(trainer.model.config.id2label.values())
classes = [f'class__{c}' for c in sorted(allowed_classes)]

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def get_preds_ids(prediction):
    # multi-label decision: keep every class whose sigmoid probability > 0.5
    prediction = sigmoid(prediction)
    return np.argwhere(prediction > 0.5).reshape(-1)

def tokenize_function(examples):
    return tokenizer(
        examples["whisper_text"],
        padding="max_length",
        truncation=True,
        max_length=model.config.max_length
    )

def get_classes_names(model, ids):
    return set(map(lambda x: model.config.id2label[x], ids))

################################################################################
# PREPARING THE DATASET #
################################################################################

# preparing the data
reference_file = os.path.join(args.reference_dataset, f"{args.subset}.json")
if not os.path.exists(reference_file):
    raise FileNotFoundError(
        f"The dataset reference for {args.subset=} cannot be found at "
        f"{reference_file}."
    )
reference = pd.read_json(reference_file, orient='index')

if args.lowercase_text:
    reference['whisper_text'] = reference['whisper_text'].str.lower()
    # hf_dataset = hf_dataset.map(lambda x: {'whisper_text': x['whisper_text'].lower()})

hf_dataset = datasets.Dataset.from_pandas(reference)

tokenized_hf_dataset = hf_dataset.map(
    tokenize_function, batched=True, batch_size=1000
)

################################################################################
# PREDICTING AND SAVING RESULTS #
################################################################################

preds = trainer.predict(tokenized_hf_dataset)
output = pd.DataFrame(columns=['text', *classes]).astype(
    {c: bool for c in classes}
)

for (sample_id, sample), pred in zip(reference.iterrows(), preds.predictions):
    output.loc[sample_id, 'text'] = sample['whisper_text']
    # default every class to False, then flip the predicted ones to True
    for c in allowed_classes:
        output.loc[sample_id, f'class__{c}'] = False

    for c in get_classes_names(trainer.model, get_preds_ids(pred)):
        output.loc[sample_id, f'class__{c}'] = True

output = output.to_dict(orient='index')
with open(f"{output_dir}/predictions.json", "w") as f:
    json.dump(
        output, f,
        ensure_ascii=False, indent='\t'
    )

with open(f"{output_dir}/model-config.json", "w") as f:
    json.dump(
        trainer.model.config.to_dict(), f,
        ensure_ascii=False, indent='\t'
    )
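For intuition, the decision rule in `get_preds_ids` treats each class independently: each logit goes through a sigmoid and the class is kept when the resulting probability exceeds 0.5, which is equivalent to the raw logit being positive. A small sketch with made-up logits:

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Made-up logits for one sample and three classes.
logits = np.array([1.2, -0.3, 0.0])
probs = sigmoid(logits)                          # approx. [0.77, 0.43, 0.50]
pred_ids = np.argwhere(probs > 0.5).reshape(-1)
print(pred_ids)                                  # [0]: only the first class passes the threshold
```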

requirements.txt (+4)
@@ -0,0 +1,4 @@
datasets
transformers
pandas
tqdm
