Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions experiments/evals/exp1602b_lm_eval_selected.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2025 The Marin Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Run the selected LM Eval Harness tasks across a set of Marin, Qwen 2.5, OLMo 2, Llama 3, and OLMo 3 models.
"""

from collections.abc import Iterable
from dataclasses import replace

from experiments.evals.evals import default_eval
from experiments.evals.resource_configs import SINGLE_TPU_V5p_8_FULL
from experiments.evals.task_configs import LM_EVAL_HARNESS_SELECTED_TASKS
from experiments.models import (
llama_3_1_8b,
llama_3_70b,
marin_8b_base,
marin_32b_base,
olmo_2_base_32b,
olmo_2_base_8b,
olmo_3_32b,
olmo_3_7b,
qwen2_5_32b,
)
from marin.evaluation.evaluation_config import EvalTaskConfig
from marin.execution.executor import ExecutorStep, executor_main

# Model steps grouped by family so that a whole family can be switched on or
# off in ALL_MODEL_STEPS below.
MARIN_MODELS: tuple[ExecutorStep, ...] = (marin_8b_base, marin_32b_base)
QWEN_2_5_MODELS: tuple[ExecutorStep, ...] = (qwen2_5_32b, )
OLMO_2_MODELS: tuple[ExecutorStep, ...] = (olmo_2_base_8b, olmo_2_base_32b)
LLAMA_3_MODELS: tuple[ExecutorStep, ...] = (llama_3_1_8b, llama_3_70b)
OLMO_3_MODELS: tuple[ExecutorStep, ...] = (olmo_3_7b, olmo_3_32b)

# Models this script actually evaluates. Only the OLMo 3 family is enabled;
# the other families are commented out, so they are defined above but not run.
# NOTE(review): the module docstring lists all five families — confirm whether
# the commented-out entries should be re-enabled before merging.
ALL_MODEL_STEPS: tuple[ExecutorStep, ...] = (
    # *MARIN_MODELS,
    # *QWEN_2_5_MODELS,
    # *OLMO_2_MODELS,
    # *LLAMA_3_MODELS,
    *OLMO_3_MODELS,
)


def _create_per_task_eval_steps(model_step: ExecutorStep, tasks: Iterable[EvalTaskConfig]) -> list[ExecutorStep]:
    """Build a separate evaluation step for every task in ``tasks``.

    Args:
        model_step: The model step handed to ``default_eval`` as ``step``.
        tasks: Task configs; each one gets its own single-task eval step.

    Returns:
        The eval steps in the same order as ``tasks``. Each step is the result
        of ``default_eval`` with its name extended to
        ``"<default_eval step name>/<task alias or task name>"``.
    """

    def _single_task_step(task_config: EvalTaskConfig) -> ExecutorStep:
        base_step = default_eval(
            step=model_step,
            resource_config=SINGLE_TPU_V5p_8_FULL,
            evals=(task_config,),
            discover_latest_checkpoint=False,
        )
        # Use the alias when one is set, otherwise fall back to the task name.
        suffix = task_config.task_alias or task_config.name
        # Appending the task to the step name shows which harness task a given
        # step runs, which helps when scheduling and debugging.
        return replace(base_step, name=f"{base_step.name}/{suffix}")

    return [_single_task_step(task_config) for task_config in tasks]


# Flatten the per-model, per-task steps into one list: models in
# ALL_MODEL_STEPS order, and within each model the tasks in
# LM_EVAL_HARNESS_SELECTED_TASKS order.
eval_steps: list[ExecutorStep] = [
    task_step
    for model in ALL_MODEL_STEPS
    for task_step in _create_per_task_eval_steps(model, LM_EVAL_HARNESS_SELECTED_TASKS)
]

if __name__ == "__main__":
    # Steps are handed to executor_main four at a time instead of through a
    # single executor_main(steps=eval_steps) call.
    # NOTE(review): the reason for batching is not stated here — confirm.
    batch_size = 4
    for batch_start in range(0, len(eval_steps), batch_size):
        executor_main(steps=eval_steps[batch_start : batch_start + batch_size])
265 changes: 265 additions & 0 deletions experiments/evals/task_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,34 @@
)

# LM-Eval-Harness Tasks
# Each (task name, few-shot count) pair below becomes one EvalTaskConfig, in the
# order listed.
# NOTE(review): this tuple was previously labelled "multiple choice tasks", but
# squadv2 and minerva_math look like generative tasks — confirm the label.
LM_EVAL_HARNESS_SELECTED_TASKS = tuple(
    EvalTaskConfig(task_name, shots)
    for task_name, shots in (
        ("copa", 0),
        ("mmlu", 5),
        ("leaderboard_musr", 0),
        ("anli_r1", 0),
        ("anli_r2", 0),
        ("anli_r3", 0),
        ("truthfulqa_mc2", 6),
        ("race", 0),
        ("toxigen", 0),
        ("agieval_lsat_ar", 3),
        ("arc_easy", 10),
        ("arc_challenge", 10),
        ("leaderboard_bbh", 3),
        ("boolq", 10),
        ("commonsense_qa", 10),
        ("leaderboard_gpqa", 0),
        ("hellaswag", 10),
        ("leaderboard_mmlu_pro", 5),
        ("openbookqa", 0),
        ("piqa", 10),
        ("winogrande", 0),
        ("wsc273", 0),
        ("squadv2", 0),
        ("minerva_math", 4),
    )
)


# Reasoning and Logic Tasks
REASONING_TASKS = (
EvalTaskConfig("anli_r1", 0, task_alias="anli_r1_0shot"),
Expand Down Expand Up @@ -311,6 +339,243 @@
EvalTaskConfig("webqs", 0, task_alias="webqs_0shot"),
)

# One zero-shot EvalTaskConfig per Belebele language variant. The task name and
# the alias are the same string, "belebele_<code>", so both are derived from a
# single list of codes.
BELEBELE_TASKS = tuple(
    EvalTaskConfig(f"belebele_{code}", 0, task_alias=f"belebele_{code}")
    for code in (
        "acm_Arab",
        "afr_Latn",
        "als_Latn",
        "amh_Ethi",
        "apc_Arab",
        "arb_Arab",
        "arb_Latn",
        "ars_Arab",
        "ary_Arab",
        "arz_Arab",
        "asm_Beng",
        "azj_Latn",
        "bam_Latn",
        "ben_Beng",
        "ben_Latn",
        "bod_Tibt",
        "bul_Cyrl",
        "cat_Latn",
        "ceb_Latn",
        "ces_Latn",
        "ckb_Arab",
        "dan_Latn",
        "deu_Latn",
        "ell_Grek",
        "eng_Latn",
        "est_Latn",
        "eus_Latn",
        "fin_Latn",
        "fra_Latn",
        "fuv_Latn",
        "gaz_Latn",
        "grn_Latn",
        "guj_Gujr",
        "hat_Latn",
        "hau_Latn",
        "heb_Hebr",
        "hin_Deva",
        "hin_Latn",
        "hrv_Latn",
        "hun_Latn",
        "hye_Armn",
        "ibo_Latn",
        "ilo_Latn",
        "ind_Latn",
        "isl_Latn",
        "ita_Latn",
        "jav_Latn",
        "jpn_Jpan",
        "kac_Latn",
        "kan_Knda",
        "kat_Geor",
        "kaz_Cyrl",
        "kea_Latn",
        "khk_Cyrl",
        "khm_Khmr",
        "kin_Latn",
        "kir_Cyrl",
        "kor_Hang",
        "lao_Laoo",
        "lin_Latn",
        "lit_Latn",
        "lug_Latn",
        "luo_Latn",
        "lvs_Latn",
        "mal_Mlym",
        "mar_Deva",
        "mkd_Cyrl",
        "mlt_Latn",
        "mri_Latn",
        "mya_Mymr",
        "nld_Latn",
        "nob_Latn",
        "npi_Deva",
        "npi_Latn",
        "nso_Latn",
        "nya_Latn",
        "ory_Orya",
        "pan_Guru",
        "pbt_Arab",
        "pes_Arab",
        "plt_Latn",
        "pol_Latn",
        "por_Latn",
        "ron_Latn",
        "rus_Cyrl",
        "shn_Mymr",
        "sin_Latn",
        "sin_Sinh",
        "slk_Latn",
        "slv_Latn",
        "sna_Latn",
        "snd_Arab",
        "som_Latn",
        "sot_Latn",
        "spa_Latn",
        "srp_Cyrl",
        "ssw_Latn",
        "sun_Latn",
        "swe_Latn",
        "swh_Latn",
        "tam_Taml",
        "tel_Telu",
        "tgk_Cyrl",
        "tgl_Latn",
        "tha_Thai",
        "tir_Ethi",
        "tsn_Latn",
        "tso_Latn",
        "tur_Latn",
        "ukr_Cyrl",
        "urd_Arab",
        "urd_Latn",
        "uzn_Latn",
        "vie_Latn",
        "war_Latn",
        "wol_Latn",
        "xho_Latn",
        "yor_Latn",
        "zho_Hans",
        "zho_Hant",
        "zsm_Latn",
        "zul_Latn",
    )
)

# One 5-shot EvalTaskConfig per language. The task name is
# "include_base_44_<language>_few_shot_og" and the alias is that name with
# "_5_shot" appended, so both are derived from the language label alone.
FEW_SHOT_OG_MULTILINGUAL_TASKS = tuple(
    EvalTaskConfig(
        f"include_base_44_{language}_few_shot_og",
        num_fewshot=5,
        task_alias=f"include_base_44_{language}_few_shot_og_5_shot",
    )
    for language in (
        "albanian",
        "arabic",
        "armenian",
        "azerbaijani",
        "basque",
        "belarusian",
        "bengali",
        "bulgarian",
        "chinese",
        "croatian",
        "dutch",
        # "estonian",
        # Disabled: Estonian test domains top out at 2 docs, so 5-shot sampling fails.
        "finnish",
        "french",
        "georgian",
        "german",
        "greek",
        "hebrew",
        "hindi",
        "hungarian",
        "indonesian",
        "italian",
        "japanese",
        "kazakh",
        "korean",
        "lithuanian",
        "malay",
        # "malayalam",
        # Disabled: Malayalam domains have <=4 test docs, so 5-shot sampling fails.
        "nepali",
        # The embedded space is part of the task name as written in this file.
        "north macedonian",
        "persian",
        # "polish",
        # Disabled: Polish domains have <=4 test docs, so 5-shot sampling fails.
        "portuguese",
        "russian",
        "serbian",
        "spanish",
        "tagalog",
        "tamil",
        "telugu",
        "turkish",
        "ukrainian",
        # "urdu",
        # Disabled: Urdu domains have <=3 test docs, so 5-shot sampling fails.
        "uzbek",
        "vietnamese",
    )
)

# Zero-shot "mgsm_direct_<lang>" tasks, each aliased as "mgsm_direct_<lang>_0shot".
MGSM_MULTILINGUAL_TASKS = tuple(
    EvalTaskConfig(f"mgsm_direct_{lang}", 0, task_alias=f"mgsm_direct_{lang}_0shot")
    for lang in ("de", "es", "fr", "ja", "ru", "zh")
)

# Zero-shot "xstorycloze_<lang>" tasks, each aliased as "xstorycloze_<lang>_0shot".
XSTORYCLOZE_MULTILINGUAL_TASKS = tuple(
    EvalTaskConfig(f"xstorycloze_{lang}", 0, task_alias=f"xstorycloze_{lang}_0shot")
    for lang in ("ar", "es", "id", "ru", "zh")
)

# 5-shot "mmmlu_<locale>" tasks. The task name is built once per locale and
# reused to form the "<task name>_5shot" alias.
MMMLU_MULTILINGUAL_TASKS = tuple(
    EvalTaskConfig(task_name, 5, task_alias=f"{task_name}_5shot")
    for task_name in (
        f"mmmlu_{locale}"
        for locale in (
            "ar_xy",
            "bn_bd",
            "de_de",
            "es_la",
            "fr_fr",
            "hi_in",
            "id_id",
            "it_it",
            "ja_jp",
            "ko_kr",
            "pt_br",
            "sw_ke",
            "yo_ng",
            "zh_cn",
        )
    )
)

# Only the cmmlu entry is active in this tuple; the commented-out lines below
# name other task groups and tasks that are currently not included.
# NOTE(review): the "_ziqing" alias suffix differs from the "<task>_0shot"
# aliases used by the other multilingual tuples in this file — confirm it is
# intentional and not a leftover from a personal run.
MULTILINGUAL_LM_EVAL_LOGPROB_TASKS = (
    EvalTaskConfig("cmmlu", 0, task_alias="cmmlu_0shot_ziqing"),
    # FEW_SHOT_OG_MULTILINGUAL_TASKS
    # + BELEBELE_TASKS
    # + XSTORYCLOZE_MULTILINGUAL_TASKS
    # MMMLU_MULTILINGUAL_TASKS
    # + (
    #
    # EvalTaskConfig("kmmlu", 0, task_alias="kmmlu_0shot"),
    # EvalTaskConfig("lm_syneval", 0, task_alias="lm_syneval_0shot"),
    # EvalTaskConfig("zhoblimp", 0, task_alias="zhoblimp_0shot"),
    # EvalTaskConfig("turblimp_core", 0, task_alias="turblimp_core_0shot"),
    # EvalTaskConfig("blimp_nl", 0, task_alias="blimp_nl_0shot"),
    # )
)

# Bound to the same tuple object as MGSM_MULTILINGUAL_TASKS (no copy is made).
MULTILINGUAL_LM_EVAL_GENERATIVE_TASKS = MGSM_MULTILINGUAL_TASKS


def convert_to_levanter_task_config(tasks: Sequence[EvalTaskConfig]) -> list[TaskConfig]:
"""
Expand Down
Loading
Loading