Skip to content

Commit

Permalink
implemented flores200
Browse files Browse the repository at this point in the history
  • Loading branch information
jjbuschhoff committed Oct 19, 2023
1 parent 93e3595 commit fcb4fda
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
8 changes: 8 additions & 0 deletions lm_eval/tasks/opengptx/all_tasks_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# OpenGPT-X tasks
from . import flores200
from . import german_europarl_ppl
from . import german_ler_ppl
from . import germanquad
Expand All @@ -17,6 +18,12 @@
from . import xquad
from . import xnli

########################################
# Translation tasks
########################################

# Flores200 language codes (ISO 639-3 code + script) that make up the
# five-language European translation benchmark registered below.
euro5_flores_benchmark = [
    "eng_Latn",
    "deu_Latn",
    "fra_Latn",
    "ita_Latn",
    "spa_Latn",
]


TASK_REGISTRY_TMP = {
# OpenGPT-X tasks
Expand All @@ -40,6 +47,7 @@
"xstance_fr": x_stance.XStanceFR,
**xquad.construct_tasks(),
**xnli.construct_tasks(),
**flores200.create_tasks_from_list(euro5_flores_benchmark)
}

# add a prefix to tasks implemented by OpenGPT-X
Expand Down
120 changes: 120 additions & 0 deletions lm_eval/tasks/opengptx/flores200.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
NOTE: This file implements the Flores200 translation task, see
https://github.com/facebookresearch/flores/tree/main/flores200.
"""
import pycountry
import datasets
from itertools import permutations
from pprint import pprint
from sacrebleu import sacrebleu
from lm_eval import metrics
from lm_eval.base import Task, rf
from lm_eval.tasks.translation import code_to_language
from typing import List



########################################
# Tasks
########################################
class FloresTranslationTask(Task):
    """Flores200 translation task for one directed language pair.

    The pair is given as ``"<src_code>-<tgt_code>"`` (for example
    ``"eng_Latn-deu_Latn"``) and is also used as the dataset configuration
    name. Documents are read from the ``dev`` (validation) and ``devtest``
    (test) splits, and the source/target sentences are looked up under the
    ``sentence_<code>`` keys of each document.
    """

    DATASET_PATH = "facebook/flores"

    def __init__(self, language_pair: str = None):
        """Configure the task for ``language_pair`` and initialise the base task.

        :param language_pair: str
            Source and target Flores200 codes joined by a single hyphen.
        :raises ValueError:
            If ``language_pair`` is missing or is not of the form
            ``"<src>-<tgt>"``.
        """
        if not isinstance(language_pair, str) or language_pair.count("-") != 1:
            raise ValueError(
                "language_pair must look like '<src_code>-<tgt_code>', "
                f"got {language_pair!r}"
            )
        self.DATASET_NAME = self.language_pair = language_pair
        self.src_code, self.tgt_code = language_pair.split("-")
        # Only the first three characters (the language part of the code) are
        # passed to the name lookup; the script suffix is dropped.
        self.src_lang = code_to_language(self.src_code[:3])
        self.tgt_lang = code_to_language(self.tgt_code[:3])
        # Nothing in this class sets `self.dataset`, which the *_docs methods
        # read, so the base initialiser has to run. It is called last so that
        # DATASET_NAME is already set.
        # NOTE(review): presumably Task.__init__ loads the dataset from
        # DATASET_PATH / DATASET_NAME -- confirm against lm_eval.base.Task.
        super().__init__()

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def has_training_docs(self):
        return False

    def test_docs(self):
        return self.dataset["devtest"]

    def validation_docs(self):
        return self.dataset["dev"]

    def doc_to_text(self, doc):
        """Build the prompt: the source sentence followed by the target cue."""
        # Document keys are built from the Flores code (as in
        # doc_to_decontamination_query / doc_to_target), not from the
        # human-readable language name used for the prompt labels.
        source_sentence = doc[f"sentence_{self.src_code}"]
        return (
            f"{self.src_lang} phrase: " + source_sentence
            + f"\n{self.tgt_lang} phrase:"
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc[f"sentence_{self.src_code}"]

    def doc_to_target(self, doc):
        return doc[f"sentence_{self.tgt_code}"]

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # Generation stops at the first newline.
        return rf.greedy_until(ctx, {"until": ["\n"]})

    def process_results(self, doc, results):
        """Pair the reference translation with the model output for each metric.

        :param doc:
            The document the results belong to.
        :param results:
            The results of the requests created in construct_requests.
        :returns: {str: tuple}
            The same ``(reference, results)`` pair under each metric name.
        """
        # These metrics are corpus-level not sentence level, so we'll hide the
        # results in this dict and compute the corpus score in the aggregate method
        ref_pred = (self.doc_to_target(doc), results)
        return {
            "bleu": ref_pred,
            "chrf": ref_pred,
            "ter": ref_pred,
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "bleu": metrics.bleu,
            "chrf": metrics.chrf,
            "ter": metrics.ter,
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "bleu": True,
            "chrf": True,
            "ter": False,
        }

    def __str__(self):
        return f"Flores200 {self.src_lang} to {self.tgt_lang} Task"

def create_translation_task(language_pair, version=0):
    """Return a zero-argument task class bound to ``language_pair``.

    :param language_pair: str
        The ``"<src>-<tgt>"`` pair forwarded to FloresTranslationTask.
    :param version: int
        Value stored as the generated class's ``VERSION`` attribute.
    :returns:
        A FloresTranslationTask subclass whose constructor takes no arguments.
    """

    class TranslationTask(FloresTranslationTask):
        VERSION = version

        def __init__(self):
            # The pair is captured from the enclosing call, so the class can
            # be instantiated without arguments.
            FloresTranslationTask.__init__(self, language_pair)

    return TranslationTask

def create_tasks_from_list(lang_list):
    """Create Flores200 task classes for every ordered pair from ``lang_list``.

    Each pair is generated in both directions.

    :param lang_list:
        Iterable of language codes.
    :returns: {str: class}
        Mapping from ``"flores200-<src>-<tgt>"`` to the task class for that
        direction.
    """
    registry = {}
    for source, target in permutations(lang_list, 2):
        pair = f"{source}-{target}"
        registry[f"flores200-{pair}"] = create_translation_task(pair)
    return registry

0 comments on commit fcb4fda

Please sign in to comment.