forked from EleutherAI/lm-evaluation-harness
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
93e3595
commit fcb4fda
Showing
2 changed files
with
128 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
""" | ||
NOTE: This file implements the Flores200 translation task, see | ||
https://github.com/facebookresearch/flores/tree/main/flores200. | ||
""" | ||
import pycountry | ||
import datasets | ||
from itertools import permutations | ||
from pprint import pprint | ||
from sacrebleu import sacrebleu | ||
from lm_eval import metrics | ||
from lm_eval.base import Task, rf | ||
from lm_eval.tasks.translation import code_to_language | ||
from typing import List | ||
|
||
|
||
|
||
######################################## | ||
# Tasks | ||
######################################## | ||
class FloresTranslationTask(Task): | ||
DATASET_PATH = "facebook/flores" | ||
|
||
def __init__(self, language_pair: str=None): | ||
self.DATASET_NAME = self.language_pair = language_pair | ||
self.src_code, self.tgt_code = language_pair.split("-") | ||
self.src_lang = code_to_language(self.src_code[:3]) | ||
self.tgt_lang = code_to_language(self.tgt_code[:3]) | ||
|
||
def has_validation_docs(self): | ||
return True | ||
|
||
def has_test_docs(self): | ||
return True | ||
|
||
def has_training_docs(self): | ||
return False | ||
|
||
def test_docs(self): | ||
return self.dataset["devtest"] | ||
|
||
def validation_docs(self): | ||
return self.dataset["dev"] | ||
|
||
def doc_to_text(self, doc): | ||
return f"{self.src_lang} phrase: " + doc[f"sentence_{self.src_lang}"]\ | ||
+ f"\n{self.tgt_lang} phrase:" | ||
|
||
def should_decontaminate(self): | ||
return True | ||
|
||
def doc_to_decontamination_query(self, doc): | ||
return doc[f"sentence_{self.src_code}"] | ||
|
||
def doc_to_target(self, doc): | ||
return doc[f"sentence_{self.tgt_code}"] | ||
|
||
def construct_requests(self, doc, ctx): | ||
"""Uses RequestFactory to construct Requests and returns an iterable of | ||
Requests which will be sent to the LM. | ||
:param doc: | ||
The document as returned from training_docs, validation_docs, or test_docs. | ||
:param ctx: str | ||
The context string, generated by fewshot_context. This includes the natural | ||
language description, as well as the few shot examples, and the question | ||
part of the document for `doc`. | ||
""" | ||
return rf.greedy_until(ctx, {"until": ["\n"]}) | ||
|
||
def process_results(self, doc, results): | ||
# These metrics are corpus-level not sentence level, so we'll hide the | ||
# results in this dict and compute the corpus score in the aggregate method | ||
ref_pred = (self.doc_to_target(doc), results) | ||
return { | ||
"bleu": ref_pred, | ||
"chrf": ref_pred, | ||
"ter": ref_pred, | ||
} | ||
|
||
def aggregation(self): | ||
""" | ||
:returns: {str: [float] -> float} | ||
A dictionary where keys are the names of submetrics and values are | ||
functions that aggregate a list of metrics | ||
""" | ||
return { | ||
"bleu": metrics.bleu, | ||
"chrf": metrics.chrf, | ||
"ter": metrics.ter, | ||
} | ||
|
||
def higher_is_better(self): | ||
""" | ||
:returns: {str: bool} | ||
A dictionary where keys are the names of submetrics and values are | ||
whether a higher value of the submetric is better | ||
""" | ||
return { | ||
"bleu": True, | ||
"chrf": True, | ||
"ter": False, | ||
} | ||
|
||
def __str__(self): | ||
return f"Flores200 {self.src_lang} to {self.tgt_lang} Task" | ||
|
||
def create_translation_task(language_pair, version=0): | ||
class TranslationTask(FloresTranslationTask): | ||
VERSION = version | ||
|
||
def __init__(self): | ||
super().__init__(language_pair) | ||
|
||
return TranslationTask | ||
|
||
def create_tasks_from_list(lang_list): | ||
"""Symmetrically create all flores200 tasks from a list of languages. | ||
""" | ||
return {f"flores200-{src}-{tgt}":create_translation_task(f"{src}-{tgt}") | ||
for src, tgt in permutations(lang_list, 2)} |