Skip to content

Commit

Permalink
Fix exact-match (EM) metric bugs and add separator-token handling
Browse files Browse the repository at this point in the history
  • Loading branch information
smarterliu committed Oct 16, 2024
1 parent 71438ec commit c0ac96b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
6 changes: 5 additions & 1 deletion lm_eval/tasks/hotpot_qa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

f1_gen = evaluate.load("./metrics/f1")
exact_match = evaluate.load("./metrics/exact_match")
sep_tokens = ["<unused2>", "<0x02>", "<|reserved_special_token_2|>"]

def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
Expand Down Expand Up @@ -74,8 +75,11 @@ def _extract_facts(context):

def process_results(doc, results):
completion = results[0]
for sep_token in sep_tokens:
if sep_token in completion:
completion = completion.split(sep_token)[1]
ans = doc["answer"]
exact_score = exact_match(references=[ans], predictions=[completion])
exact_score = exact_match.compute(references=[ans], predictions=[completion])["exact_match"]
ans_toks = get_tokens(ans)
completion_toks = get_tokens(completion)
common = collections.Counter(ans_toks) & collections.Counter(completion_toks)
Expand Down
6 changes: 5 additions & 1 deletion lm_eval/tasks/nq_open/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

f1_gen = evaluate.load("./metrics/f1")
exact_match = evaluate.load("./metrics/exact_match")
sep_tokens = ["<unused2>", "<0x02>", "<|reserved_special_token_2|>"]

def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
Expand Down Expand Up @@ -65,8 +66,11 @@ def _extract_facts(docs):

def process_results(doc, results):
completion = results[0]
for sep_token in sep_tokens:
if sep_token in completion:
completion = completion.split(sep_token)[1]
ans = doc["answer"]
exact_score = exact_match(references=[ans], predictions=[completion])
exact_score = exact_match.compute(references=[ans], predictions=[completion])["exact_match"]
ans_toks = get_tokens(ans)
completion_toks = get_tokens(completion)
common = collections.Counter(ans_toks) & collections.Counter(completion_toks)
Expand Down
6 changes: 5 additions & 1 deletion lm_eval/tasks/triviaqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

f1_gen = evaluate.load("./metrics/f1")
exact_match = evaluate.load("./metrics/exact_match")
sep_tokens = ["<unused2>", "<0x02>", "<|reserved_special_token_2|>"]

def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
Expand Down Expand Up @@ -61,8 +62,11 @@ def _extract_facts(docs):

def process_results(doc, results):
completion = results[0]
for sep_token in sep_tokens:
if sep_token in completion:
completion = completion.split(sep_token)[1]
ans = doc["answer"]
exact_score = exact_match(references=[ans], predictions=[completion])
exact_score = exact_match.compute(references=[ans], predictions=[completion])["exact_match"]
ans_toks = get_tokens(ans)
completion_toks = get_tokens(completion)
common = collections.Counter(ans_toks) & collections.Counter(completion_toks)
Expand Down

0 comments on commit c0ac96b

Please sign in to comment.