Skip to content

Commit

Permalink
satisfy linter
Browse files Browse the repository at this point in the history
  • Loading branch information
jjbuschhoff committed Oct 19, 2023
1 parent fcb4fda commit aa6f2a7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
2 changes: 1 addition & 1 deletion lm_eval/tasks/opengptx/all_tasks_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"xstance_fr": x_stance.XStanceFR,
**xquad.construct_tasks(),
**xnli.construct_tasks(),
**flores200.create_tasks_from_list(euro5_flores_benchmark)
**flores200.create_tasks_from_list(euro5_flores_benchmark),
}

# add a prefix to tasks implemented by OpenGPT-X
Expand Down
39 changes: 22 additions & 17 deletions lm_eval/tasks/opengptx/flores200.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,47 +13,49 @@
from typing import List



########################################
# Tasks
########################################
class FloresTranslationTask(Task):
DATASET_PATH = "facebook/flores"
def __init__(self, language_pair: str = None):
    """Configure the task for one translation direction.

    Args:
        language_pair: Two language codes joined by a single hyphen, in
            the form ``"<src_code>-<tgt_code>"``. The value is also stored
            as ``DATASET_NAME``.

    Raises:
        ValueError: If ``language_pair`` is missing or does not consist of
            exactly two hyphen-separated codes.
    """
    # The default of None is never usable; fail with a clear message instead
    # of an AttributeError on ``None.split``.
    if language_pair is None:
        raise ValueError(
            "language_pair is required, e.g. '<src_code>-<tgt_code>'"
        )
    codes = language_pair.split("-")
    if len(codes) != 2:
        raise ValueError(
            f"language_pair must be '<src_code>-<tgt_code>', got {language_pair!r}"
        )

    self.DATASET_NAME = self.language_pair = language_pair
    self.src_code, self.tgt_code = codes
    # Only the first three characters of each code are handed to
    # code_to_language to obtain the name used in prompts and __str__.
    self.src_lang = code_to_language(self.src_code[:3])
    self.tgt_lang = code_to_language(self.tgt_code[:3])

def has_validation_docs(self):
    """Return True: validation documents are served by ``validation_docs``."""
    return True

def has_test_docs(self):
    """Return True: test documents are served by ``test_docs``."""
    return True

def has_training_docs(self):
    """Return False: this task exposes no training documents."""
    return False

def test_docs(self):
return self.dataset["devtest"]

def validation_docs(self):
    """Return the ``"dev"`` split of ``self.dataset`` as the validation documents."""
    splits = self.dataset
    return splits["dev"]

def doc_to_text(self, doc):
    """Build the translation prompt for one document.

    The prompt has the form
    ``"<src_lang> phrase: <source sentence>\\n<tgt_lang> phrase:"``.

    Args:
        doc: Mapping holding the source sentence under the key
            ``sentence_<src_code>``.

    Returns:
        The prompt string, ending right after the target-language label.
    """
    # Sentence columns are keyed by language *code* (as in
    # doc_to_decontamination_query and doc_to_target); the language *name*
    # is only used for the human-readable labels.
    source_sentence = doc[f"sentence_{self.src_code}"]
    return (
        f"{self.src_lang} phrase: "
        + source_sentence
        + f"\n{self.tgt_lang} phrase:"
    )

def should_decontaminate(self):
    """Return True: a query is supplied by ``doc_to_decontamination_query``."""
    return True

def doc_to_decontamination_query(self, doc):
    """Return the source-language sentence of ``doc`` as the decontamination query."""
    source_key = f"sentence_{self.src_code}"
    return doc[source_key]

def doc_to_target(self, doc):
    """Return the target-language sentence of ``doc`` (the reference translation)."""
    target_key = f"sentence_{self.tgt_code}"
    return doc[target_key]

def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
Expand All @@ -66,7 +68,7 @@ def construct_requests(self, doc, ctx):
part of the document for `doc`.
"""
return rf.greedy_until(ctx, {"until": ["\n"]})

def process_results(self, doc, results):
# These metrics are corpus-level not sentence level, so we'll hide the
# results in this dict and compute the corpus score in the aggregate method
Expand Down Expand Up @@ -104,6 +106,7 @@ def higher_is_better(self):
def __str__(self):
    """Return a readable label naming the source and target languages."""
    source, target = self.src_lang, self.tgt_lang
    return f"Flores200 {source} to {target} Task"


def create_translation_task(language_pair, version=0):
class TranslationTask(FloresTranslationTask):
VERSION = version
Expand All @@ -113,8 +116,10 @@ def __init__(self):

return TranslationTask


def create_tasks_from_list(lang_list):
    """Build flores200 task classes for every ordered pair in ``lang_list``.

    Both directions of each pair are produced (``a-b`` and ``b-a``).

    Args:
        lang_list: Iterable of language codes.

    Returns:
        Dict mapping ``"flores200-<src>-<tgt>"`` to the class returned by
        ``create_translation_task("<src>-<tgt>")``.
    """
    registry = {}
    for pair in permutations(lang_list, 2):
        src, tgt = pair
        registry[f"flores200-{src}-{tgt}"] = create_translation_task(f"{src}-{tgt}")
    return registry

0 comments on commit aa6f2a7

Please sign in to comment.