diff --git a/evaluate.py b/evaluate.py index 21188d5..eefdb6b 100644 --- a/evaluate.py +++ b/evaluate.py @@ -25,6 +25,7 @@ from seqeval.metrics import recall_score from third_party.evaluate_mlqa import evaluate as mlqa_eval from third_party.evaluate_squad import evaluate as squad_eval +from third_party.utils_mewslix import evaluate as mewslix_eval def read_tag(file): @@ -108,6 +109,12 @@ def mlqa_em_f1(labels, predictions, language): return mlqa_eval(labels, predictions, language) +def mewslix_map20(labels, predictions, language=None): + del language + mrr = mewslix_eval(labels, predictions, k=20) + return {'map@20': mrr * 100} + + XTREME_GROUP2TASK = { 'classification': ['pawsx', 'xnli'], 'tagging': ['udpos', 'panx'], @@ -120,7 +127,7 @@ def mlqa_em_f1(labels, predictions, language): 'classification': ['xnli', 'xcopa'], 'tagging': ['udpos', 'panx'], 'qa': ['xquad', 'mlqa', 'tydiqa'], - 'retrieval': ['tatoeba'], + 'retrieval': ['tatoeba', 'mewslix'], 'multi_choice': ['xcopa'], } @@ -162,7 +169,7 @@ def mlqa_em_f1(labels, predictions, language): 'ro'.split(','), 'xcopa': 'et,ht,id,it,qu,sw,ta,th,tr,vi,zh'.split(','), 'lareqa': [], - 'mewslix': [], + 'mewslix': 'ar,de,en,es,fa,ja,pl,ro,ta,tr,uk'.split(','), 'xquad': 'en,es,de,el,ru,tr,ar,vi,th,zh,hi,ro'.split(','), 'mlqa': 'en,es,de,ar,hi,vi,zh'.split(','), 'tydiqa': 'en,ar,bn,fi,id,ko,ru,sw,te'.split(','), @@ -183,6 +190,7 @@ def mlqa_em_f1(labels, predictions, language): 'bucc2018': read_label, 'tatoeba': read_label, 'xquad': read_squad, + 'mewslix': read_squad, 'mlqa': read_squad, 'tydiqa': read_squad, 'xcopa': read_xcopa, @@ -197,6 +205,7 @@ def mlqa_em_f1(labels, predictions, language): 'bucc2018': bucc_f1, 'tatoeba': accuracy, 'xquad': squad_em_f1, + 'mewslix': mewslix_map20, 'mlqa': mlqa_em_f1, 'tydiqa': squad_em_f1, 'xcopa': accuracy, @@ -219,7 +228,7 @@ def evaluate_one_task(prediction_file, label_file, task, language=None): """ predictions = READER_FUNCTION[task](prediction_file) labels = READER_FUNCTION[task](label_file) - if task not in ['bucc2018', 'mlqa', 'tydiqa', 'xquad']: + if task not in ['bucc2018', 'mewslix', 'mlqa', 'tydiqa', 'xquad']: assert len(predictions) == len(labels), ( 'Number of examples in {} and {} not matched in {} task'.format( prediction_file, label_file, task)) @@ -227,6 +236,15 @@ def evaluate_one_task(prediction_file, label_file, task, language=None): return result +def get_suffix(task, group2task): + if task in group2task['qa'] or task in ('mewslix',): + return 'json' + elif 'multi_choice' in group2task and task in group2task['multi_choice']: + return 'jsonl' + else: + return 'tsv' + + def evaluate(prediction_folder, label_folder, xtreme_version, verbose=False): """Evaluate on all tasks if available. @@ -250,12 +268,7 @@ def evaluate(prediction_folder, label_folder, xtreme_version, verbose=False): detailed_scores = {} for task, langs in task2langs.items(): if task in prediction_tasks and task in label_tasks: - if task in group2task['qa']: - suffix = 'json' - elif 'multi_choice' in group2task and task in group2task['multi_choice']: - suffix = 'jsonl' - else: - suffix = 'tsv' + suffix = get_suffix(task, group2task) # collect scores over all languages score = collections.defaultdict(dict) for lg in langs: diff --git a/evaluate_test.py b/evaluate_test.py index 240d143..6d4874e 100644 --- a/evaluate_test.py +++ b/evaluate_test.py @@ -20,12 +20,15 @@ from absl.testing import absltest from absl.testing import parameterized from xtreme.evaluate import evaluate_one_task +from xtreme.evaluate import get_suffix from xtreme.evaluate import XTREME_GROUP2TASK +from xtreme.evaluate import XTREME_R_GROUP2TASK +from xtreme.evaluate import XTREME_R_TASK2LANGS from xtreme.evaluate import XTREME_TASK2LANGS DATA_DIR = './/mock_test_data' -# Mock submission scores for testing +# Mock submission scores for testing XTREME. TASK2AVG_SCORES = { 'pawsx': {'avg_accuracy': 51.42857142857143}, 'xnli': {'avg_accuracy': 30.666666666666668}, @@ -40,6 +43,24 @@ 'tydiqa': {'avg_exact_match': 88.88888888888889, 'avg_f1': 97.22222222222223} } +# Mock submission scores for testing XTREME-R. + # TODO(ruder): Update data/numbers for tasks with added languages (UD-POS, + # PANX, Tatoeba, and XQuAD) and for new tasks (XCOPA, LAReQA). +XTREME_R_TASK2AVG_SCORES = { + 'xnli': {'avg_accuracy': 30.666666666666668}, + # 'panx': {'avg_f1': 57.50793650793652, 'avg_precision': 54.729166666666664, + # 'avg_recall': 62.750000000000014}, + # 'udpos': {'avg_f1': 70.21746048354693, 'avg_precision': 71.02232625883823, + # 'avg_recall': 69.54982073976082}, + # 'tatoeba': {'avg_accuracy': 53.611111111111114}, + # 'xcopa' + # 'lareqa' + 'mewslix': {'avg_map@20': 14.39025156130419}, + # 'xquad': {'avg_exact_match': 77.27272727272727, 'avg_f1': 79.9586776859504}, + 'mlqa': {'avg_exact_match': 57.142857142857146, 'avg_f1': 81.76870748299321}, + 'tydiqa': {'avg_exact_match': 88.88888888888889, 'avg_f1': 97.22222222222223} +} + class EvaluateTest(parameterized.TestCase): """Test cases for evaluate.py.""" @@ -54,7 +75,7 @@ class EvaluateTest(parameterized.TestCase): ('XQuAD', 'xquad'), ('MLQA', 'mlqa'), ('TyDiQA', 'tydiqa')) - def testTask(self, task): + def testXtremeTask(self, task): data_dir = os.path.join(absltest.get_default_test_srcdir(), DATA_DIR) suffix = 'json' if task in XTREME_GROUP2TASK['qa'] else 'tsv' score = collections.defaultdict(dict) @@ -71,5 +92,32 @@ def testTask(self, task): self.assertEqual(avg_score, TASK2AVG_SCORES[task]) + @parameterized.named_parameters( + ('XNLI', 'xnli'), + # ('PANX', 'panx'), + # ('UDPOS', 'udpos'), + # ('Tatoeba', 'tatoeba'), + # ('XCOPA', 'xcopa'), + # ('LAReQA', 'lareqa'), + ('Mewsli-X', 'mewslix'), + # ('XQuAD', 'xquad'), + ('MLQA', 'mlqa'), + ('TyDiQA', 'tydiqa')) + def testXtremeRTask(self, task): + data_dir = os.path.join(absltest.get_default_test_srcdir(), DATA_DIR) + suffix = get_suffix(task, XTREME_R_GROUP2TASK) + score = collections.defaultdict(dict) + for lg in XTREME_R_TASK2LANGS[task]: + pred_file = os.path.join(data_dir, 'predictions', task, + f'test-{lg}.{suffix}') + label_file = os.path.join(data_dir, 'labels', task, f'test-{lg}.{suffix}') + score_lg = evaluate_one_task(pred_file, label_file, task, language=lg) + for metric in score_lg: + score[metric][lg] = score_lg[metric] + avg_score = {} + for m in score: + avg_score[f'avg_{m}'] = sum(score[m].values()) / len(score[m]) + self.assertEqual(avg_score, XTREME_R_TASK2AVG_SCORES[task]) + if __name__ == '__main__': absltest.main() diff --git a/mock_test_data/labels/mewslix/test-ar.json b/mock_test_data/labels/mewslix/test-ar.json new file mode 100644 index 0000000..7c1deee --- /dev/null +++ b/mock_test_data/labels/mewslix/test-ar.json @@ -0,0 +1 @@ +{"8e18e51eced73e6495df0043192edbfe": ["Q46930"]} \ No newline at end of file diff --git a/mock_test_data/labels/mewslix/test-de.json b/mock_test_data/labels/mewslix/test-de.json new file mode 100644 index 0000000..cd5111b --- /dev/null +++ b/mock_test_data/labels/mewslix/test-de.json @@ -0,0 +1 @@ +{"4be5a1742223cc3a8c01e6bf9c6e3f27": ["Q156913"]} \ No newline at end of file diff --git a/mock_test_data/labels/mewslix/test-en.json b/mock_test_data/labels/mewslix/test-en.json new file mode 100644 index 0000000..be723eb --- /dev/null +++ b/mock_test_data/labels/mewslix/test-en.json @@ -0,0 +1 @@ +{"64ca9e2f229acf8e39c2a3d2e45f81e7": ["Q720285"]} \ No newline at end of file diff --git a/mock_test_data/labels/mewslix/test-es.json b/mock_test_data/labels/mewslix/test-es.json new file mode 100644 index 0000000..9290987 --- /dev/null +++ b/mock_test_data/labels/mewslix/test-es.json @@ -0,0 +1 @@ +{"4a2d7fd3e4791f09bc3c804a15d647ef": ["Q786"]} \ No newline at end of file diff --git a/mock_test_data/labels/mewslix/test-fa.json b/mock_test_data/labels/mewslix/test-fa.json new file mode 100644 index 0000000..d6bc05a --- /dev/null +++ b/mock_test_data/labels/mewslix/test-fa.json @@ -0,0 +1 @@ +{"d35cc57a7869168ddeb8143c1b2260f3": ["Q76"]} \ No newline at end of file diff --git a/mock_test_data/labels/mewslix/test-ja.json b/mock_test_data/labels/mewslix/test-ja.json new file mode 100644 index 0000000..633dbc8 --- /dev/null +++ b/mock_test_data/labels/mewslix/test-ja.json @@ -0,0 +1 @@ +{"d0e7a9dd0359610c53bba176d702dfce": ["Q174691"]} \ No newline at end of file diff --git a/mock_test_data/labels/mewslix/test-pl.json b/mock_test_data/labels/mewslix/test-pl.json new file mode 100644 index 0000000..623e696 --- /dev/null +++ b/mock_test_data/labels/mewslix/test-pl.json @@ -0,0 +1 @@ +{"64232b8a3c3ee67f76f96ccd963b78f7": ["Q1362561"]} \ No newline at end of file diff --git a/mock_test_data/labels/mewslix/test-ro.json b/mock_test_data/labels/mewslix/test-ro.json new file mode 100644 index 0000000..a14d6af --- /dev/null +++ b/mock_test_data/labels/mewslix/test-ro.json @@ -0,0 +1 @@ +{"ebd92132adbb679fdd090503cd925f81": ["Q185007"]} \ No newline at end of file diff --git a/mock_test_data/labels/mewslix/test-ta.json b/mock_test_data/labels/mewslix/test-ta.json new file mode 100644 index 0000000..ce6c1d4 --- /dev/null +++ b/mock_test_data/labels/mewslix/test-ta.json @@ -0,0 +1 @@ +{"12760cb39680a822c3cd0c8495cf1b4b": ["Q11468"]} \ No newline at end of file diff --git a/mock_test_data/labels/mewslix/test-tr.json b/mock_test_data/labels/mewslix/test-tr.json new file mode 100644 index 0000000..35fadf3 --- /dev/null +++ b/mock_test_data/labels/mewslix/test-tr.json @@ -0,0 +1 @@ +{"9f39acb0fef259aaf24224fe41954f6c": ["Q258"]} \ No newline at end of file diff --git a/mock_test_data/labels/mewslix/test-uk.json b/mock_test_data/labels/mewslix/test-uk.json new file mode 100644 index 0000000..1e69387 --- /dev/null +++ b/mock_test_data/labels/mewslix/test-uk.json @@ -0,0 +1 @@ +{"9f4dba86a6d21cfd246353403da46abd": ["Q1899"]} \ No newline at end of file diff --git a/mock_test_data/predictions/mewslix/test-ar.json b/mock_test_data/predictions/mewslix/test-ar.json new file mode 100644 index 0000000..6aaa93a --- /dev/null +++ b/mock_test_data/predictions/mewslix/test-ar.json @@ -0,0 +1 @@ +{"8e18e51eced73e6495df0043192edbfe": ["Q4963862", "Q42309905", "Q45789", "Q13403337", "Q5564588", "Q4009605", "Q1635932", "Q4980057", "Q5958027", "Q233750", "Q2922959", "Q203023", "Q2425422", "Q2340576", "Q4639323", "Q46930", "Q66891", "Q5423986", "Q15556629", "Q1347825"]} \ No newline at end of file diff --git a/mock_test_data/predictions/mewslix/test-de.json b/mock_test_data/predictions/mewslix/test-de.json new file mode 100644 index 0000000..a6a5ad7 --- /dev/null +++ b/mock_test_data/predictions/mewslix/test-de.json @@ -0,0 +1 @@ +{"4be5a1742223cc3a8c01e6bf9c6e3f27": ["Q11490423", "Q156913", "Q490356", "Q16222746", "Q4873731", "Q2102531", "Q209944", "Q4630241", "Q9033638", "Q18249334", "Q65216438", "Q333185", "Q2530561", "Q20013418", "Q10826362", "Q2575270", "Q2914850", "Q55697199", "Q853167", "Q111730"]} \ No newline at end of file diff --git a/mock_test_data/predictions/mewslix/test-en.json b/mock_test_data/predictions/mewslix/test-en.json new file mode 100644 index 0000000..39729c6 --- /dev/null +++ b/mock_test_data/predictions/mewslix/test-en.json @@ -0,0 +1 @@ +{"64ca9e2f229acf8e39c2a3d2e45f81e7": ["Q82674", "Q13551861", "Q2418898", "Q198748", "Q1146387", "Q6730240", "Q6769706", "Q2315496", "Q3375182", "Q711611", "Q55732114", "Q720285", "Q4760035", "Q28670149", "Q375278", "Q260559", "Q82840", "Q878942", "Q269810", "Q427535"]} \ No newline at end of file diff --git a/mock_test_data/predictions/mewslix/test-es.json b/mock_test_data/predictions/mewslix/test-es.json new file mode 100644 index 0000000..d935182 --- /dev/null +++ b/mock_test_data/predictions/mewslix/test-es.json @@ -0,0 +1 @@ +{"4a2d7fd3e4791f09bc3c804a15d647ef": ["Q6151759", "Q19904197", "Q1138905", "Q440165", "Q787524", "Q13050046", "Q15748660", "Q6604140", "Q11400285", "Q20071151", "Q2912875", "Q786", "Q1999706", "Q11398056", "Q4486275", "Q3744158", "Q63524702", "Q38745473", "Q37996883", "Q29260670"]} \ No newline at end of file diff --git a/mock_test_data/predictions/mewslix/test-fa.json b/mock_test_data/predictions/mewslix/test-fa.json new file mode 100644 index 0000000..6a998c8 --- /dev/null +++ b/mock_test_data/predictions/mewslix/test-fa.json @@ -0,0 +1 @@ +{"d35cc57a7869168ddeb8143c1b2260f3": ["Q333972", "Q48270", "Q5254564", "Q76", "Q5947394", "Q3151708", "Q1756916", "Q63091766", "Q13104276", "Q5839704", "Q6598064", "Q1008989", "Q48762758", "Q55842144", "Q461358", "Q447087", "Q13640998", "Q535894", "Q223278", "Q3504372"]} \ No newline at end of file diff --git a/mock_test_data/predictions/mewslix/test-ja.json b/mock_test_data/predictions/mewslix/test-ja.json new file mode 100644 index 0000000..4683f3a --- /dev/null +++ b/mock_test_data/predictions/mewslix/test-ja.json @@ -0,0 +1 @@ +{"d0e7a9dd0359610c53bba176d702dfce": ["Q1210312", "Q3662301", "Q2877167", "Q13548902", "Q3458109", "Q65159649", "Q49892", "Q204547", "Q12699816", "Q372592", "Q1776619", "Q16633277", "Q1658454", "Q174691", "Q1053638", "Q23653996", "Q798074", "Q24939391", "Q8037644", "Q65967892"]} \ No newline at end of file diff --git a/mock_test_data/predictions/mewslix/test-pl.json b/mock_test_data/predictions/mewslix/test-pl.json new file mode 100644 index 0000000..ad18d54 --- /dev/null +++ b/mock_test_data/predictions/mewslix/test-pl.json @@ -0,0 +1 @@ +{"64232b8a3c3ee67f76f96ccd963b78f7": ["Q1033066", "Q565472", "Q11598441", "Q29522", "Q16027287", "Q1174348", "Q1052293", "Q16903684", "Q12860947", "Q48769622", "Q2606279", "Q7315521", "Q268776", "Q13621486", "Q1400430", "Q7124665", "Q11280748", "Q710911", "Q1362561", "Q34754"]} \ No newline at end of file diff --git a/mock_test_data/predictions/mewslix/test-ro.json b/mock_test_data/predictions/mewslix/test-ro.json new file mode 100644 index 0000000..635add1 --- /dev/null +++ b/mock_test_data/predictions/mewslix/test-ro.json @@ -0,0 +1 @@ +{"ebd92132adbb679fdd090503cd925f81": ["Q1144739", "Q5836568", "Q20582855", "Q1311", "Q711832", "Q185007", "Q311559", "Q50391138", "Q55418237", "Q5037965", "Q601712", "Q6654524", "Q615949", "Q980941", "Q5188638", "Q15060144", "Q6737309", "Q21670139", "Q1040955", "Q928053"]} \ No newline at end of file diff --git a/mock_test_data/predictions/mewslix/test-ta.json b/mock_test_data/predictions/mewslix/test-ta.json new file mode 100644 index 0000000..4ee474d --- /dev/null +++ b/mock_test_data/predictions/mewslix/test-ta.json @@ -0,0 +1 @@ +{"12760cb39680a822c3cd0c8495cf1b4b": ["Q22959171", "Q13385006", "Q608803", "Q3046191", "Q1750336", "Q15353797", "Q1695555", "Q124473", "Q836937", "Q3297349", "Q430687", "Q2181287", "Q11468", "Q20393369", "Q888226", "Q56477015", "Q22692651", "Q13829184", "Q2479497", "Q3207103"]} \ No newline at end of file diff --git a/mock_test_data/predictions/mewslix/test-tr.json b/mock_test_data/predictions/mewslix/test-tr.json new file mode 100644 index 0000000..a887e8f --- /dev/null +++ b/mock_test_data/predictions/mewslix/test-tr.json @@ -0,0 +1 @@ +{"9f39acb0fef259aaf24224fe41954f6c": ["Q11350542", "Q188447", "Q15905812", "Q15868", "Q6630136", "Q6734763", "Q105927", "Q258", "Q9181720", "Q313196", "Q4099359", "Q15567185", "Q587455", "Q190436", "Q5284896", "Q18709782", "Q16233625", "Q5246694", "Q11620425", "Q12568992"]} \ No newline at end of file diff --git a/mock_test_data/predictions/mewslix/test-uk.json b/mock_test_data/predictions/mewslix/test-uk.json new file mode 100644 index 0000000..e291854 --- /dev/null +++ b/mock_test_data/predictions/mewslix/test-uk.json @@ -0,0 +1 @@ +{"9f4dba86a6d21cfd246353403da46abd": ["Q524624", "Q3830755", "Q3800390", "Q508679", "Q20383186", "Q930701", "Q18682623", "Q16969424", "Q1899", "Q2320371", "Q266613", "Q2469647", "Q749794", "Q6241038", "Q5754881", "Q2879448", "Q1630799", "Q447", "Q628319", "Q25515301"]} \ No newline at end of file