Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 2 additions & 34 deletions evaluation/benchmarks/loogle/calculate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from nltk.translate.meteor_score import single_meteor_score
from rouge import Rouge

from .parsing import get_exact_match, get_partial_match


# Code below is adapted from https://github.com/bigai-nlco/LooGLE/blob/main/Evaluation/automatic_metrics.py
def get_bleu_score(reference, hypothesis):
Expand Down Expand Up @@ -37,40 +39,6 @@ def get_meteor_score(reference, hypothesis):
return {"meteor": float(meteor)}


def get_exact_match(reference, hypothesis):
    """Count the reference cloze answers that the prediction reproduces exactly.

    Args:
        reference: String holding a Python dict literal of expected answers.
        hypothesis: String holding a Python dict literal of predicted answers.

    Returns:
        A ``(exact_score_count, count)`` tuple, where ``count`` is the number
        of reference entries. ``(0, 1)`` is returned when either string cannot
        be parsed or the hypothesis is not a dict.
    """
    # Local import so the block does not depend on the file's import header.
    import ast

    try:
        # SECURITY: this used to call eval(), which executes arbitrary code
        # embedded in a model prediction. ast.literal_eval only accepts plain
        # literals, so expressions such as __import__(...) are rejected.
        reference = ast.literal_eval(reference)
        count = len(reference)
        hypothesis = ast.literal_eval(hypothesis)
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(hypothesis, dict):
            raise TypeError("hypothesis must be a dict literal")
    except Exception:
        # Best effort: any unparsable input scores zero out of one.
        return 0, 1

    exact_score_count = sum(1 for key, value in reference.items() if hypothesis.get(key) == value)
    return exact_score_count, count


def get_partial_match(reference, hypothesis):
    """Count the reference cloze answers that share at least one word with the prediction.

    Args:
        reference: String holding a Python dict literal of expected answers.
        hypothesis: String holding a Python dict literal of predicted answers.

    Returns:
        A ``(partial_score_count, count)`` tuple, where ``count`` is the number
        of reference entries. ``(0, count)`` is returned when the hypothesis
        cannot be parsed, is not a dict, or cannot be compared.

    Raises:
        ValueError, SyntaxError: If ``reference`` is not a valid literal; the
            reference is parsed outside the ``try`` block on purpose.
    """
    # Local import so the block does not depend on the file's import header.
    import ast

    # SECURITY: this used to call eval(); ast.literal_eval only accepts plain
    # literals and never executes code embedded in the string.
    reference = ast.literal_eval(reference)
    count = len(reference)
    try:
        hypothesis = ast.literal_eval(hypothesis)
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(hypothesis, dict):
            raise TypeError("hypothesis must be a dict literal")
        partial_score_count = 0
        for key in reference:
            if key in hypothesis:
                true_set = set(reference[key].split())
                pred_set = set(hypothesis[key].split())
                # A single shared whitespace-separated token counts as a match.
                if true_set & pred_set:
                    partial_score_count += 1
        return partial_score_count, count
    except Exception:
        # Best effort: any malformed prediction scores zero.
        return 0, count


def try_except_metric(metric_fn):
def wrapped_metric(answer, predicted_answer):
try:
Expand Down
22 changes: 20 additions & 2 deletions evaluation/benchmarks/loogle/create_huggingface_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import pandas as pd
from datasets import Dataset, load_dataset

from .parsing import parse_qa_pairs

# Templates based on https://github.com/bigai-nlco/LooGLE/blob/main/config/task2prompt.json
context_prompt = {
"shortdep_qa": "Please answer the question based on the long texts below. \n{input}",
Expand All @@ -32,15 +34,17 @@
# Source: https://github.com/bigai-nlco/LooGLE/blob/main/config/task2maxlen.json
max_new_tokens = {"shortdep_qa": 300, "longdep_qa": 500, "longdep_summarization": 500, "shortdep_cloze": 50}

for task in ["shortdep_qa", "longdep_qa", "shortdep_cloze", "longdep_summarization"]:
TASKS = ("shortdep_qa", "longdep_qa", "shortdep_cloze", "longdep_summarization")


def build_task_dataframe(task: str) -> pd.DataFrame:
df = load_dataset("bigainlco/LooGLE", task, split="test", trust_remote_code=True).to_pandas()

if task == "longdep_summarization":
df["question"] = ""
df = df.rename(columns={"output": "answer", "input": "context"})
else:
df["qa_pairs"] = df["qa_pairs"].apply(lambda x: eval(x) if x != "none" else [{"Q": "", "A": "", "S": [""]}])
df["qa_pairs"] = df["qa_pairs"].apply(parse_qa_pairs)
df = df.explode("qa_pairs")
df = pd.concat([df.drop(["qa_pairs"], axis=1), df["qa_pairs"].apply(pd.Series)], axis=1)
df = df.rename(columns={"A": "answer", "Q": "question", "input": "context"})
Expand All @@ -53,7 +57,21 @@
df = df[["context", "question", "answer_prefix", "answer"]]
df["task"] = task
df["max_new_tokens"] = max_new_tokens[task]
return df


def push_task_dataset(task: str) -> None:
    """Build the dataframe for ``task`` and push it to the hub as that config's test split."""
    frame = build_task_dataframe(task)

    # Push to hub
    Dataset.from_pandas(frame).push_to_hub("simonjegou/loogle", config_name=task, split="test")


def main() -> None:
    """Push the dataset of every task listed in ``TASKS``, in order."""
    for task_name in TASKS:
        push_task_dataset(task_name)


# Entry point: call main() only when this module is run as a script.
if __name__ == "__main__":
    main()
68 changes: 68 additions & 0 deletions evaluation/benchmarks/loogle/parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import ast
import json
from typing import Any


def parse_loogle_literal(value: str) -> Any:
    """Parse JSON first, then fall back to Python literals for model outputs.

    Args:
        value: Text expected to hold a JSON document or a Python literal.

    Returns:
        The decoded object.

    Raises:
        ValueError: If the text is not a literal, or is too large or too
            deeply nested for either parser.
        SyntaxError: If the text is not syntactically valid Python.
    """
    try:
        return json.loads(value)
    except (json.JSONDecodeError, RecursionError):
        # Not JSON, or nested too deeply for the JSON decoder: try the
        # Python literal parser, which never executes code.
        try:
            return ast.literal_eval(value)
        except (MemoryError, RecursionError) as err:
            # Pathological input can exhaust the parser. Report it as an
            # ordinary parse failure so callers that catch ValueError treat
            # it like any other malformed prediction.
            raise ValueError("Input is too large or too deeply nested to parse.") from err


def parse_cloze_answers(value: str) -> dict[str, str]:
    """Parse ``value`` and check that it is a dict with string keys and string values.

    Raises:
        ValueError: If the parsed object is not a dict, or any key or value
            is not a string.
    """
    answers = parse_loogle_literal(value)
    if not isinstance(answers, dict):
        raise ValueError("Expected a dictionary of cloze answers.")
    for key, answer in answers.items():
        if not (isinstance(key, str) and isinstance(answer, str)):
            raise ValueError("Expected string keys and values in cloze answers.")
    return answers


def get_exact_match(reference: str, hypothesis: str) -> tuple[int, int]:
    """Count the reference answers whose predicted value is exactly equal.

    Returns:
        ``(matches, number of reference answers)``, or ``(0, 1)`` when either
        input fails to parse as cloze answers.
    """
    try:
        expected = parse_cloze_answers(reference)
        predicted = parse_cloze_answers(hypothesis)
    except (SyntaxError, TypeError, ValueError):
        return 0, 1

    matches = sum(1 for key, value in expected.items() if predicted.get(key) == value)
    return matches, len(expected)


def get_partial_match(reference: str, hypothesis: str) -> tuple[int, int]:
    """Count the reference answers that share at least one word with the prediction.

    The reference is parsed outside the ``try`` block, so a malformed
    reference raises instead of being scored.

    Returns:
        ``(hits, number of reference answers)``; ``hits`` is 0 when the
        hypothesis fails to parse as cloze answers.
    """
    expected = parse_cloze_answers(reference)
    total = len(expected)
    try:
        predicted = parse_cloze_answers(hypothesis)
        hits = 0
        for key, truth in expected.items():
            if key not in predicted:
                continue
            # Any common whitespace-separated token counts as a partial match.
            if set(truth.split()) & set(predicted[key].split()):
                hits += 1
        return hits, total
    except (SyntaxError, TypeError, ValueError):
        return 0, total


def parse_qa_pairs(value: str) -> list[dict[str, Any]]:
    """Parse a ``qa_pairs`` cell into a list of dictionaries.

    The sentinel string ``"none"`` yields a single placeholder entry with
    empty question, answer and evidence fields. Any other value is parsed
    with ``ast.literal_eval`` and must be a list of dicts.

    Raises:
        ValueError: If the parsed value is not a list, or holds a non-dict item.
    """
    if value != "none":
        pairs = ast.literal_eval(value)
        if not isinstance(pairs, list):
            raise ValueError("Expected qa_pairs to be a list.")
        if any(not isinstance(item, dict) for item in pairs):
            raise ValueError("Expected qa_pairs to contain dictionaries.")
        return pairs

    return [{"Q": "", "A": "", "S": [""]}]
59 changes: 59 additions & 0 deletions tests/test_loogle_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import os

from evaluation.benchmarks.loogle.parsing import get_exact_match, get_partial_match, parse_qa_pairs


def test_get_exact_match_accepts_python_literal_predictions():
    """A single-quoted Python dict prediction matches a JSON reference."""
    score = get_exact_match('{"<mask-0>": "Bob"}', "{'<mask-0>': 'Bob'}")

    assert score == (1, 1)


def test_get_exact_match_returns_zero_if_prediction_is_not_a_literal(monkeypatch):
    """An executable expression scores zero and its side effect never runs."""
    monkeypatch.delenv("LOOGLE_EVAL_EXECUTED", raising=False)
    # If evaluated, this payload would set an environment variable.
    malicious_prediction = (
        "(__import__('os').environ.__setitem__('LOOGLE_EVAL_EXECUTED', '1'), "
        "{'<mask-0>': 'Bob'})[1]"
    )

    score = get_exact_match('{"<mask-0>": "Bob"}', malicious_prediction)

    assert score == (0, 1)
    assert "LOOGLE_EVAL_EXECUTED" not in os.environ


def test_get_partial_match_accepts_python_literal_predictions():
    """A Python-literal prediction sharing one word with the reference counts as a hit."""
    score = get_partial_match('{"<mask-0>": "Bob Smith"}', "{'<mask-0>': 'Smith'}")

    assert score == (1, 1)


def test_parse_qa_pairs_returns_default_if_none():
    """The "none" sentinel yields the single empty placeholder entry."""
    placeholder = [{"Q": "", "A": "", "S": [""]}]

    assert parse_qa_pairs("none") == placeholder


def test_parse_qa_pairs_accepts_python_literal_lists():
    """A Python-literal list of dicts is parsed into the equivalent objects."""
    parsed = parse_qa_pairs("[{'Q': 'q', 'A': 'a', 'S': ['s']}]")

    assert parsed == [{"Q": "q", "A": "a", "S": ["s"]}]


def test_parse_qa_pairs_rejects_non_literal_payload(monkeypatch):
    """An executable expression is rejected and its side effect never runs."""
    monkeypatch.delenv("LOOGLE_QA_PAIRS_EVAL_EXECUTED", raising=False)
    # If evaluated, this payload would set an environment variable.
    malicious_value = (
        "(__import__('os').environ.__setitem__('LOOGLE_QA_PAIRS_EVAL_EXECUTED', '1'), "
        "[{'Q': 'q', 'A': 'a', 'S': ['s']}])[1]"
    )

    rejected = False
    try:
        parse_qa_pairs(malicious_value)
    except (SyntaxError, ValueError):
        rejected = True
    if not rejected:
        raise AssertionError("parse_qa_pairs should reject non-literal payloads.")

    assert "LOOGLE_QA_PAIRS_EVAL_EXECUTED" not in os.environ