-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathgenerations_evaluate_answerability.py
More file actions
77 lines (62 loc) · 2.48 KB
/
generations_evaluate_answerability.py
File metadata and controls
77 lines (62 loc) · 2.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import logging
from dataclasses import dataclass, field
from logging.config import fileConfig
from pathlib import Path
import pandas as pd
import simple_parsing
from sklearn.metrics import classification_report
from tqdm.contrib.logging import logging_redirect_tqdm
# Configure the logging system from "logging.ini" as soon as this module is
# loaded. The path is relative, so the file is looked up in the current
# working directory.
fileConfig("logging.ini")
# Module-level logger used by the rest of this script.
logger = logging.getLogger(__name__)
@dataclass
class Args:
    """Command-line options for the answerability evaluation script."""

    # JSON-lines file holding the generations to evaluate (required).
    generation_file: Path
    # Directory the metrics JSON file is written to.
    output_dir: Path = Path("out")
    # JSON-lines file with the questions and their answerability labels.
    qa_file: Path = Path("data/qa.jsonl")
    # Recompute the metrics even when the metrics file already exists.
    override: bool = False
    # Text whose presence in a generation marks it as "unanswerable".
    no_answer_string: str = "No Answer"
def main(args: Args):
    """Evaluate how well generations separate answerable from unanswerable questions.

    A generation is counted as predicting "unanswerable" when it contains
    ``args.no_answer_string`` (matched literally) and "answerable" otherwise.
    The predictions are compared with the ``answerable_mapped`` labels from
    ``args.qa_file``, joined on ``paper_id`` and ``question_id``. The
    scikit-learn classification report is logged and written as JSON to
    ``metrics-answerability-{generation file name}.json`` inside
    ``args.output_dir``.

    Nothing is computed when that metrics file already exists, unless
    ``args.override`` is set.

    Args:
        args: Parsed command-line options (see ``Args``).
    """
    out_file = (
        args.output_dir / f"metrics-answerability-{args.generation_file.name}.json"
    )
    if out_file.exists() and not args.override:
        logger.info(f"Metrics file {out_file} already exists. Skipping.")
        return

    # Load the data
    logger.info(f"Loading data from {args.generation_file} and {args.qa_file}.")
    gen_df = pd.read_json(args.generation_file, lines=True)
    # Drop any answerable column shipped with the generations; the label is
    # taken from qa_df instead. errors="ignore" keeps generation files that
    # never had this column working.
    gen_df = gen_df.drop(columns=["answerable"], errors="ignore")
    qa_df = pd.read_json(args.qa_file, lines=True)
    logger.info(f"qa_df shape: {qa_df.shape}")

    logger.info("Preprocessing generations.")
    # Attach the gold answerability label to every generation
    gen_df = pd.merge(
        gen_df,
        qa_df[["paper_id", "question_id", "answerable_mapped"]],
        on=["paper_id", "question_id"],
        how="left",
    )
    # Drop rows with missing answerable_mapped, as we can't evaluate them
    gen_df = gen_df.dropna(subset=["answerable_mapped"])
    logger.info(f"gen_df shape: {gen_df.shape}")

    # regex=False: the marker is a plain string, so characters such as
    # "(", "[" or "." in it must not be interpreted as a regular expression.
    says_no_answer = gen_df["generation"].str.contains(
        args.no_answer_string, regex=False
    )
    # 1 = answerable, 0 = unanswerable
    y_pred = 1 - says_no_answer.astype(int)
    y_true = gen_df["answerable_mapped"].astype(int)
    logger.info(f"y_pred shape: {y_pred.shape} {y_pred.sum()=}")
    logger.info(f"y_true shape: {y_true.shape} {y_true.sum()=}")

    clf_report_kwargs = dict(
        y_true=y_true,
        y_pred=y_pred,
        # Pin both classes explicitly (assumes 0/1 labels, as target_names
        # already does). Without this the report raises when only one class
        # occurs in the data, because target_names would no longer match.
        labels=[0, 1],
        target_names=["Unanswerable", "Answerable"],
        zero_division=0,
    )
    clf_report_str = classification_report(**clf_report_kwargs)
    logger.info("\n" + clf_report_str)

    logger.info("Saving classification report.")
    clf_report = classification_report(**clf_report_kwargs, output_dict=True)
    # Make sure the output directory exists before writing into it
    out_file.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(clf_report).to_json(out_file, indent=2)
if __name__ == "__main__":
    # Parse the options declared on Args; the second element returned by
    # parse_known_args is not used.
    cli_args, _unused = simple_parsing.parse_known_args(Args)
    # Run the whole evaluation inside tqdm's logging redirection context.
    with logging_redirect_tqdm():
        logger.info(cli_args)
        main(cli_args)