From fcb4fda0311e0f2877308915e7a8d9bc8fdd49f3 Mon Sep 17 00:00:00 2001 From: Jasper Schulze Buschhoff Date: Thu, 19 Oct 2023 16:24:57 +0200 Subject: [PATCH] implemented flores200 --- lm_eval/tasks/opengptx/all_tasks_registry.py | 8 ++ lm_eval/tasks/opengptx/flores200.py | 120 +++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 lm_eval/tasks/opengptx/flores200.py diff --git a/lm_eval/tasks/opengptx/all_tasks_registry.py b/lm_eval/tasks/opengptx/all_tasks_registry.py index b05b47e5ec..701322cf5d 100644 --- a/lm_eval/tasks/opengptx/all_tasks_registry.py +++ b/lm_eval/tasks/opengptx/all_tasks_registry.py @@ -1,4 +1,5 @@ # OpenGPT-X tasks +from . import flores200 from . import german_europarl_ppl from . import german_ler_ppl from . import germanquad @@ -17,6 +18,12 @@ from . import xquad from . import xnli +######################################## +# Translation tasks +######################################## + +euro5_flores_benchmark = ["eng_Latn", "deu_Latn", "fra_Latn", "ita_Latn", "spa_Latn"] + TASK_REGISTRY_TMP = { # OpenGPT-X tasks @@ -40,6 +47,7 @@ "xstance_fr": x_stance.XStanceFR, **xquad.construct_tasks(), **xnli.construct_tasks(), + **flores200.create_tasks_from_list(euro5_flores_benchmark) } # add a prefix to tasks implemented by OpenGPT-X diff --git a/lm_eval/tasks/opengptx/flores200.py b/lm_eval/tasks/opengptx/flores200.py new file mode 100644 index 0000000000..ac3e734158 --- /dev/null +++ b/lm_eval/tasks/opengptx/flores200.py @@ -0,0 +1,120 @@ +""" +NOTE: This file implements the Flores200 translation task, see +https://github.com/facebookresearch/flores/tree/main/flores200. +""" +import pycountry +import datasets +from itertools import permutations +from pprint import pprint +from sacrebleu import sacrebleu +from lm_eval import metrics +from lm_eval.base import Task, rf +from lm_eval.tasks.translation import code_to_language +from typing import List + + + +######################################## +# Tasks +######################################## +class FloresTranslationTask(Task): + DATASET_PATH = "facebook/flores" + + def __init__(self, language_pair: str=None): + self.DATASET_NAME = self.language_pair = language_pair + self.src_code, self.tgt_code = language_pair.split("-") + self.src_lang = code_to_language(self.src_code[:3]) + self.tgt_lang = code_to_language(self.tgt_code[:3]) + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return True + + def has_training_docs(self): + return False + + def test_docs(self): + return self.dataset["devtest"] + + def validation_docs(self): + return self.dataset["dev"] + + def doc_to_text(self, doc): + return f"{self.src_lang} phrase: " + doc[f"sentence_{self.src_lang}"]\ + + f"\n{self.tgt_lang} phrase:" + + def should_decontaminate(self): + return True + + def doc_to_decontamination_query(self, doc): + return doc[f"sentence_{self.src_code}"] + + def doc_to_target(self, doc): + return doc[f"sentence_{self.tgt_code}"] + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + return rf.greedy_until(ctx, {"until": ["\n"]}) + + def process_results(self, doc, results): + # These metrics are corpus-level not sentence level, so we'll hide the + # results in this dict and compute the corpus score in the aggregate method + ref_pred = (self.doc_to_target(doc), results) + return { + "bleu": ref_pred, + "chrf": ref_pred, + "ter": ref_pred, + } + + def aggregation(self): + """ + :returns: {str: [float] -> float} + A dictionary where keys are the names of submetrics and values are + functions that aggregate a list of metrics + """ + return { + "bleu": metrics.bleu, + "chrf": metrics.chrf, + "ter": metrics.ter, + } + + def higher_is_better(self): + """ + :returns: {str: bool} + A dictionary where keys are the names of submetrics and values are + whether a higher value of the submetric is better + """ + return { + "bleu": True, + "chrf": True, + "ter": False, + } + + def __str__(self): + return f"Flores200 {self.src_lang} to {self.tgt_lang} Task" + +def create_translation_task(language_pair, version=0): + class TranslationTask(FloresTranslationTask): + VERSION = version + + def __init__(self): + super().__init__(language_pair) + + return TranslationTask + +def create_tasks_from_list(lang_list): + """Symmetrically create all flores200 tasks from a list of languages. + """ + return {f"flores200-{src}-{tgt}":create_translation_task(f"{src}-{tgt}") + for src, tgt in permutations(lang_list, 2)} \ No newline at end of file