diff --git a/docs/references/default-steps.md b/docs/references/default-steps.md index 1f491067e1..628850c8d4 100644 --- a/docs/references/default-steps.md +++ b/docs/references/default-steps.md @@ -27,12 +27,6 @@ In general, you should reach for the default steps before writing your own. ::: experiments.defaults.simulated_epoching_train -## Scaling Law Prediction - -::: marin.scaling_laws.create_ladder_suite.scaling_law_suite - -::: experiments.defaults.default_scaling_law_pred - ## Evaluation ::: experiments.evals.evals.default_eval diff --git a/experiments/defaults.py b/experiments/defaults.py index 809f159d79..cfd7f78255 100644 --- a/experiments/defaults.py +++ b/experiments/defaults.py @@ -46,7 +46,6 @@ CORE_TASKS, MMLU_TASKS, convert_to_levanter_task_config, - convert_to_task_metrics, ) from experiments.llama import compute_num_parameters, llama_8b from experiments.paloma import paloma_tokenized @@ -59,7 +58,6 @@ InputName, VersionedValue, ensure_versioned, - get_executor_step, this_output_path, unwrap_versioned_value, ) @@ -72,7 +70,6 @@ tokenize, ) from marin.processing.tokenize.tokenize import HfTokenizeConfig, TokenizeConfigBase -from marin.scaling_laws.scaling_laws import ScalingLawConfig, run_scaling_law_analysis from marin.training.training import ( TrainLmOnPodConfig, run_levanter_train_lm, @@ -637,41 +634,3 @@ def _get_tokenizer_for_train(tokenized: InputName | ExecutorStep | LMMixtureData raise ValueError(f"Could not determine tokenizer from {tokenized}") return tokenizer - - -def default_scaling_law_pred( - ladder_runs: Sequence[ExecutorStep | InputName | str], - pred_run: ExecutorStep | InputName | str | None = None, - task_losses: Sequence[str] = ("eval/paloma/c4_en/bpb",), - task_accuracies: Sequence[str] | Sequence[EvalTaskConfig] | None = None, -): - """ - Given a suite of small models, predict the performance on a number of (N, D) values. 
- """ - # get the executor steps or run IDs for the ladder runs and the pred run - ladder_steps_or_ids = [get_executor_step(run) if not isinstance(run, str) else run for run in ladder_runs] - - pred_run_or_id = None - if pred_run: - pred_run_or_id = get_executor_step(pred_run) if not isinstance(pred_run, str) else pred_run - - # convert the task accuracies to strings if they are `EvalTaskConfig`s - if task_accuracies is not None: - task_accuracies = convert_to_task_metrics(task_accuracies, metric="acc") - - if pred_run_or_id: - name = pred_run_or_id if isinstance(pred_run_or_id, str) else pred_run_or_id.name - else: - name = "projection" - - return ExecutorStep( - name=f"""scaling_laws/{name}""", - fn=run_scaling_law_analysis, - config=ScalingLawConfig( - name=name, - ladder_model_steps=ladder_steps_or_ids, - pred_model_step=pred_run_or_id, - task_losses=task_losses, - task_accuracies=task_accuracies, - ), - ) diff --git a/experiments/exp1600_perpcorr.py b/experiments/exp1600_perpcorr.py index 1ab8d39ee6..c0cc140def 100644 --- a/experiments/exp1600_perpcorr.py +++ b/experiments/exp1600_perpcorr.py @@ -24,7 +24,7 @@ from experiments.evals.evals import evaluate_levanter_lm_evaluation_harness from experiments.evals.task_configs import EvalTaskConfig -from experiments.isoflop_sweep import generate_isoflop_sweep +from experiments.isoflop_sweep import MARIN_2025_RECIPE, create_isoflop_sweep_steps from experiments.llama import llama3_tokenizer from experiments.models import ModelConfig as HFModelConfig, download_model_step from experiments.paloma import paloma_tokenized @@ -56,22 +56,22 @@ @lru_cache(maxsize=1) def build_steps(): steps = [] - isoflop_steps, isoflop_metadatas = generate_isoflop_sweep( + isoflop_steps, isoflop_candidates = create_isoflop_sweep_steps( nemotron_mix, experiment_name="nemo-wider-depth-adapt", + recipe=MARIN_2025_RECIPE, ) - for isoflop_step, isoflop_metadata in zip(isoflop_steps, isoflop_metadatas, strict=False): + for isoflop_step, candidate in zip(isoflop_steps, isoflop_candidates, strict=False): experiment_name = isoflop_step.name.split("/")[-1] paloma_tokenized_dict = paloma_tokenized(tokenizer=llama3_tokenizer) uncheatable_eval_tokenized_dict = uncheatable_eval_tokenized(tokenizer=llama3_tokenizer) eval_data = mixture_for_evaluation(paloma_tokenized_dict | uncheatable_eval_tokenized_dict) - budget, hidden_size, num_layers, batch_size, train_steps = isoflop_metadata wandb_tags = [ - f"FLOPs={budget:.1e}", - f"d={hidden_size}", - f"L={num_layers}", - f"B={batch_size}", - f"steps={train_steps}", + f"FLOPs={candidate.flops_budget:.1e}", + f"d={candidate.hidden_size}", + f"L={candidate.num_layers}", + f"B={candidate.batch_size}", + f"steps={candidate.train_steps}", ] model_config = isoflop_step.config.train_config.model checkpoint_path = output_path_of(isoflop_step) diff --git a/experiments/exp1603_subgroup_evals.py b/experiments/exp1603_subgroup_evals.py index 26066211f8..ba0ddcbe6f 100644 --- a/experiments/exp1603_subgroup_evals.py +++ b/experiments/exp1603_subgroup_evals.py @@ -24,6 +24,10 @@ from experiments.models import ModelConfig, download_model_step from marin.execution.executor import executor_main, output_path_of, versioned from marin.evaluation.log_probs import default_lm_log_probs +from marin.processing.tokenize import get_vocab_size_for_tokenizer + +# Vocab size for building model configs +VOCAB_SIZE = get_vocab_size_for_tokenizer("stanford-crfm/marin-tokenizer") # This is painfully slow to run in dry run mode # nodryrun @@ -40,8 +44,10 @@ def 
create_eval_steps() -> list: steps = [] dist_eval = distributional_eval_sets(llama3_tokenizer) - for model, metadata in list(zip(*MARIN_SCALING_SUITES["nemotron"], strict=False)): - name = f"marin-nemo-{metadata[0]}C-{metadata[-3] * metadata[-2] * 4096}T-{metadata[1]}W-{metadata[2]}D" + for model, candidate in list(zip(*MARIN_SCALING_SUITES["nemotron"], strict=False)): + total_tokens = int(candidate.tokens) + params = candidate.model_config.total_trainable_params(VOCAB_SIZE) + name = f"marin-nemo-{candidate.flops_budget:.0e}C-{total_tokens}T-N{params:.0e}" step = evaluate_levanter_lm_evaluation_harness( model_name=name, @@ -51,9 +57,10 @@ def create_eval_steps() -> list: ) steps.append(step) + model_config = candidate.model_config logprobs_step = default_lm_log_probs( output_path_of(model).cd("checkpoints"), - metadata[-1], + model_config, dist_eval, resource_config=ResourceConfig.with_tpu("v5p-8"), checkpoint_is_hf=False, @@ -62,8 +69,10 @@ def create_eval_steps() -> list: steps.append(logprobs_step) - for model, metadata in list(zip(*MARIN_SCALING_SUITES["common_pile"], strict=False)): - name = f"marin-comma-{metadata[0]}C-{metadata[-3] * metadata[-2] * 4096}T-{metadata[1]}W-{metadata[2]}D" + for model, candidate in list(zip(*MARIN_SCALING_SUITES["common_pile"], strict=False)): + total_tokens = int(candidate.tokens) + params = candidate.model_config.total_trainable_params(VOCAB_SIZE) + name = f"marin-comma-{candidate.flops_budget:.0e}C-{total_tokens}T-N{params:.0e}" step = evaluate_levanter_lm_evaluation_harness( model_name=name, @@ -73,9 +82,10 @@ def create_eval_steps() -> list: ) steps.append(step) + model_config = candidate.model_config logprobs_step = default_lm_log_probs( output_path_of(model).cd("checkpoints"), - metadata[-1], + model_config, dist_eval, resource_config=ResourceConfig.with_tpu("v5p-8"), checkpoint_is_hf=False, @@ -84,8 +94,10 @@ def create_eval_steps() -> list: steps.append(logprobs_step) - for model, metadata in list(zip(*MARIN_SCALING_SUITES["dclm-default"], strict=False)): - name = f"marin-dclm-{metadata[0]}C-{metadata[-3] * metadata[-2] * 4096}T-{metadata[1]}W-{metadata[2]}D" + for model, candidate in list(zip(*MARIN_SCALING_SUITES["dclm-default"], strict=False)): + total_tokens = int(candidate.tokens) + params = candidate.model_config.total_trainable_params(VOCAB_SIZE) + name = f"marin-dclm-{candidate.flops_budget:.0e}C-{total_tokens}T-N{params:.0e}" step = evaluate_levanter_lm_evaluation_harness( model_name=name, @@ -95,16 +107,17 @@ def create_eval_steps() -> list: ) steps.append(step) - logprobs_step = default_lm_log_probs( - output_path_of(model).cd("checkpoints"), - metadata[-1], - dist_eval, - resource_config=ResourceConfig.with_tpu("v5p-8"), - checkpoint_is_hf=False, - name=versioned(f"{name}-DistRobust-ICE-logprobs"), - ) + model_config = candidate.model_config + logprobs_step = default_lm_log_probs( + output_path_of(model).cd("checkpoints"), + model_config, + dist_eval, + resource_config=ResourceConfig.with_tpu("v5p-8"), + checkpoint_is_hf=False, + name=versioned(f"{name}-DistRobust-ICE-logprobs"), + ) - steps.append(logprobs_step) + steps.append(logprobs_step) baselines = [ ("allenai/OLMo-2-1124-7B", "stage2-ingredient3-step8000-tokens34B"), diff --git a/experiments/exp1752_simulated_epoching.py b/experiments/exp1752_simulated_epoching.py deleted file mode 100644 index a88bde0149..0000000000 --- a/experiments/exp1752_simulated_epoching.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, 
Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Scaling law comparison between Stack v2 datasets and StarCoderData with simulated epoching.""" - -import dataclasses -import logging -from collections.abc import Sequence - -from levanter.data.text import LMMixtureDatasetConfig -from levanter.models.llama import LlamaConfig - -from experiments.common_pile.tokenize_common_pile import stackv2, stackv2_edu_filtered -from experiments.defaults import default_tokenize, simulated_epoching_train -from experiments.evals.task_configs import CORE_TASKS -from experiments.llama import llama3_tokenizer, llama_1_4b -from experiments.pretraining_datasets.dclm import dclm_components_llama3 -from experiments.simple_train_config import SimpleTrainConfig -from fray.cluster import ResourceConfig -from marin.execution.executor import ExecutorStep, InputName, executor_main - -TPU_TYPE = "v5p-8" -TAG = ["exp1752", "simulated_epoching"] - -STACK_V2_SWEEP_NAME = "exp1752-stack-v2-sim" -STACK_V2_EDU_SWEEP_NAME = "exp1752-stack-v2-edu-sim" -STARCODER_SWEEP_NAME = "exp1752-starcoderdata-sim" - -SIMULATED_TARGET_BUDGET_TOKENS = 15_000_000_000_000 # 15T tokens to mimic full-budget epoching behaviour - -training_config = SimpleTrainConfig( - resources=ResourceConfig.with_tpu(TPU_TYPE, slice_count=1), - train_batch_size=256, - learning_rate=1e-3, - weight_decay=0.1, - num_train_steps=200000, - warmup=1000, - decay=0.0, - lr_schedule="constant", - ema_beta=0.995, - steps_per_eval=500, - steps_per_task_eval=500, -) - - -def simulated_scaling_law_suite( - sweep_name: str, - tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, - *, - widths: Sequence[int] = (512, 768, 1024, 1536, 2048), - base_model_config: LlamaConfig = llama_1_4b, - tags: Sequence[str] = (), - intermediate_scale: float = 4, - training_config: SimpleTrainConfig = training_config, - base_lr: float = 3e-4 * 4096, - max_lr: float = 5e-3, - target_budget: int = SIMULATED_TARGET_BUDGET_TOKENS, -) -> Sequence[ExecutorStep]: - """Mirror scaling_law_suite but replace training with simulated epoching.""" - - steps: list[ExecutorStep] = [] - for width in widths: - intermediate_dim = _round_to_multiple(intermediate_scale * width, 128) - head_size = 128 # keeping this 128 means we can use splash attention - num_heads = width // head_size - num_kv_heads = min(num_heads, 8) - assert num_heads * head_size == width, f"Number of heads must divide width: {width} % {head_size} != 0" - - if num_heads % num_kv_heads != 0: - num_kv_heads = num_heads - - model_config = dataclasses.replace( - base_model_config, - hidden_dim=width, - intermediate_dim=intermediate_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - ) - - learning_rate = min(base_lr / width, max_lr) - lr_training_config = dataclasses.replace(training_config, learning_rate=learning_rate) - - logging.info(f"Creating simulated epoching step for {sweep_name}-{width} with lr {learning_rate}") - - steps.append( - simulated_epoching_train( - name=f"{sweep_name}-{width}", - tokenized=tokenized, - model_config=model_config, - 
train_config=lr_training_config, - target_budget=target_budget, - tags=tags, - eval_harness_tasks=CORE_TASKS, - ) - ) - return steps - - -def _round_to_multiple(x: float, multiple: int) -> int: - return int(multiple * round(x / multiple)) - - -stackv2_tokenized = default_tokenize( - name="common_pile_stackv2", - dataset=stackv2 / "documents", - tokenizer=llama3_tokenizer, -) - -stackv2_edu_tokenized = default_tokenize( - name="common_pile_stackv2_edu", - dataset=stackv2_edu_filtered, - tokenizer=llama3_tokenizer, -) - -stackv2_suite = simulated_scaling_law_suite( - sweep_name=STACK_V2_SWEEP_NAME, - tokenized=stackv2_tokenized, - tags=[*TAG, "stackv2"], - intermediate_scale=4, - training_config=training_config, -) - -stackv2_edu_suite = simulated_scaling_law_suite( - sweep_name=STACK_V2_EDU_SWEEP_NAME, - tokenized=stackv2_edu_tokenized, - tags=[*TAG, "stackv2_edu"], - intermediate_scale=4, - training_config=training_config, -) - -starcoder_suite = simulated_scaling_law_suite( - sweep_name=STARCODER_SWEEP_NAME, - tokenized=dclm_components_llama3["starcoderdata"], - tags=[*TAG, "starcoderdata"], - intermediate_scale=4, - training_config=training_config, -) - -if __name__ == "__main__": - executor_main( - steps=[ - *stackv2_suite, - *stackv2_edu_suite, - *starcoder_suite, - ], - description="Scaling law sweeps comparing Stack v2 with StarCoderData using simulated epoching.", - ) diff --git a/experiments/exp1752_stackv2_vs_starcoder.py b/experiments/exp1752_stackv2_vs_starcoder.py deleted file mode 100644 index da8b6c3746..0000000000 --- a/experiments/exp1752_stackv2_vs_starcoder.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Scaling law comparison between Stack v2 datasets and StarCoderData.""" - -from experiments.common_pile.tokenize_common_pile import stackv2, stackv2_edu_filtered -from experiments.defaults import default_tokenize -from experiments.llama import llama3_tokenizer -from experiments.pretraining_datasets.dclm import dclm_components_llama3 -from experiments.simple_train_config import SimpleTrainConfig -from fray.cluster import ResourceConfig -from marin.execution.executor import executor_main -from marin.scaling_laws.create_ladder_suite import scaling_law_suite - -TPU_TYPE = "v5p-8" -TAG = ["exp1752_stackv2_vs_starcoder"] - -STACK_V2_SWEEP_NAME = "exp1752-stack-v2" -STACK_V2_EDU_SWEEP_NAME = "exp1752-stack-v2-edu" -STARCODER_SWEEP_NAME = "exp1752-starcoderdata" - -training_config = SimpleTrainConfig( - resources=ResourceConfig.with_tpu(TPU_TYPE, slice_count=1), - train_batch_size=256, - learning_rate=1e-3, - weight_decay=0.1, - num_train_steps=200000, - warmup=1000, - decay=0.0, - lr_schedule="constant", - ema_beta=0.995, - steps_per_eval=500, - steps_per_task_eval=500, -) - - -stackv2_tokenized = default_tokenize( - name="common_pile_stackv2", - dataset=stackv2 / "documents", - tokenizer=llama3_tokenizer, -) - -stackv2_edu_tokenized = default_tokenize( - name="common_pile_stackv2_edu", - dataset=stackv2_edu_filtered, - tokenizer=llama3_tokenizer, -) - -stackv2_suite = scaling_law_suite( - sweep_name=STACK_V2_SWEEP_NAME, - tokenized=stackv2_tokenized, - tags=[*TAG, "stackv2"], - intermediate_scale=4, - training_config=training_config, -) - -stackv2_edu_suite = scaling_law_suite( - sweep_name=STACK_V2_EDU_SWEEP_NAME, - tokenized=stackv2_edu_tokenized, - tags=[*TAG, "stackv2_edu"], - intermediate_scale=4, - training_config=training_config, -) - -starcoder_suite = scaling_law_suite( - sweep_name=STARCODER_SWEEP_NAME, - tokenized=dclm_components_llama3["starcoderdata"], - tags=[*TAG, "starcoderdata"], - intermediate_scale=4, - training_config=training_config, -) - -if __name__ == "__main__": - executor_main( - steps=[ - *stackv2_suite, - *stackv2_edu_suite, - *starcoder_suite, - ], - description="Scaling law sweeps comparing Stack v2 with StarCoderData.", - ) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py new file mode 100644 index 0000000000..1ff51d1742 --- /dev/null +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -0,0 +1,276 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exp2166: Scaling Ladder Analysis for Nemotron. + +This experiment runs scaling ladder analysis on the isoflop training sweeps +for the Nemotron (nemo-wider-depth-adapt) dataset. + +The scaling ladder: +1. Fits scaling laws from IsoFLOP sweep data to find compute-optimal configurations +2. Generates visualization plots (isoflop curves and scaling fit plots) +3. 
Optionally trains compute-optimal models at larger target budgets +""" + +import json +import logging +import os +from dataclasses import dataclass, replace +from datetime import timedelta + +import fsspec +import jmp +from fray.cluster import ResourceConfig +from haliax.partitioning import ResourceAxis +from levanter.checkpoint import CheckpointerConfig +from levanter.data.text import LMDatasetSourceConfig, LMMixtureDatasetConfig +from levanter.main import train_lm +from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerConfig +from levanter.utils.mesh import MeshConfig + +from experiments.defaults import default_validation_sets +from experiments.isoflop_sweep import ( + IsoFlopAnalysisConfig, + MARIN_2025_RECIPE, + MARIN_SCALING_SUITES, + nemotron_mix, + run_isoflop_analysis_step, +) +from experiments.llama import llama3_tokenizer +from marin.execution.executor import ExecutorStep, executor_main, this_output_path +from marin.processing.tokenize import step_to_lm_mixture_component +from marin.scaling_laws import ScalingFit, predict_optimal_config +from marin.scaling_laws.tpu_utils import pick_v5p_type, HBM_PER_CHIP_GIB +from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Get training steps from the isoflop sweep +nemotron_training, _ = MARIN_SCALING_SUITES["nemotron"] + +# --- Configuration --- +TARGET_BUDGETS: list[float] = [1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20, 1e21, 1e22, 1e23, 1e24] +EXPERIMENT_NAME = "exp2166-scaling-ladder-nemotron-validation" +LABEL = "nemo-wider-depth-adapt" +SEQ_LEN = 4096 +MAX_TPU_TYPE = "v5p-64" # Cap TPU size; use gradient accumulation for larger models + + +@dataclass(frozen=True) +class OptimalTrainingConfig: + """Config for training a compute-optimal model based on scaling law analysis.""" + + analysis_output_path: str + """Path to the analysis output containing scaling fits.""" + + target_budget: float + """Target compute budget in FLOPs.""" + + label: str + """Dataset/experiment label to use for scaling fit lookup.""" + + output_path: str + """Output path for checkpoints and logs.""" + + tokenized: LMMixtureDatasetConfig + """Tokenized dataset for training. Executor will resolve InputName and unwrap VersionedValue.""" + + validation_configs: dict[str, LMDatasetSourceConfig] | None = None + """Validation set configs. Passed through config so executor resolves InputName paths.""" + + +def run_optimal_training(config: OptimalTrainingConfig) -> None: + """Run compute-optimal training at the given budget. + + Reads scaling fits from analysis output, predicts optimal config, + builds training config, and runs training directly. 
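+
+    As a sketch, the analysis JSON this reads is expected to look like
+    (the label and numbers below are illustrative, not real results):
+
+        {"scaling_fits": {"nemo-wider-depth-adapt": [2.1e-2, 0.53]}}
+
+    where each value is an (A, alpha) pair for the tokens-vs-compute fit
+    D* = A * C^alpha written by run_isoflop_analysis_step.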
+ """ + result_path = os.path.join(config.analysis_output_path, "isoflop_analysis_result.json") + fs, _, _ = fsspec.get_fs_token_paths(result_path) + + with fs.open(result_path, "r") as f: + analysis_result = json.load(f) + + scaling_fits: dict[str, ScalingFit] = {} + for key, value in analysis_result["scaling_fits"].items(): + if len(value) != 2: + raise ValueError(f"Expected 2 scaling fit values for '{key}', got {len(value)}") + scaling_fits[key] = ScalingFit(float(value[0]), float(value[1])) + + candidate = predict_optimal_config( + scaling_fits=scaling_fits, + target_flops=config.target_budget, + label=config.label, + recipe=MARIN_2025_RECIPE, + seq_len=SEQ_LEN, + ) + + if candidate is None: + raise RuntimeError( + f"Could not find optimal config for budget {config.target_budget:.2e} and label '{config.label}'" + ) + + params = candidate.model_config.total_trainable_params(MARIN_2025_RECIPE.vocab_size) + estimated_memory = MARIN_2025_RECIPE.estimate_memory_bytes(candidate) + + # Compute TPU type and gradient accumulation settings + max_cores = int(MAX_TPU_TYPE.split("-")[1]) + num_chips = max_cores // 2 + max_memory = num_chips * HBM_PER_CHIP_GIB * 1024**3 + + per_device_parallelism: int | None = None + if estimated_memory <= max_memory: + # Fits without gradient accumulation + tpu_type = pick_v5p_type(estimated_memory) + else: + # Need gradient accumulation to fit in MAX_TPU_TYPE + tpu_type = MAX_TPU_TYPE + microbatch_size = candidate.batch_size + while (microbatch_size / candidate.batch_size) * estimated_memory > max_memory: + microbatch_size //= 2 + if microbatch_size < num_chips: + raise ValueError( + f"Cannot fit model in {MAX_TPU_TYPE}: need microbatch >= {num_chips}, got {microbatch_size}" + ) + per_device_parallelism = microbatch_size // num_chips + + print( + f"Optimal config for {config.target_budget:.2e} FLOPs:\n" + f" hidden_dim={candidate.model_config.hidden_dim}, layers={candidate.model_config.num_layers}\n" + f" params={params:.2e}, tokens={candidate.tokens:.2e}\n" + f" batch_size={candidate.batch_size}, train_steps={candidate.train_steps}\n" + f" estimated_memory={estimated_memory / 1e9:.2f} GB -> {tpu_type}\n" + f" per_device_parallelism={per_device_parallelism or 'None (no grad accum)'}" + ) + + # For very large models, use aggressive gradient checkpointing to reduce memory + # Following exp1295_32b.py pattern: offload only carries, not inputs + model_config = candidate.model_config + if config.target_budget >= 1e21: + from haliax import ScanCheckpointPolicy + + model_config = replace(model_config, gradient_checkpointing=ScanCheckpointPolicy(save_carries="offload")) + logger.info("Using offload carries gradient checkpointing for large model") + + # Build TrainLmConfig directly (like old run_scaling_ladder_rung) + # config.tokenized is already processed by executor's instantiate_config + data = config.tokenized + if config.validation_configs: + # Merge validation configs into the data mixture with weight 0 + new_configs = { + **data.configs, + **{k: v for k, v in config.validation_configs.items() if k not in data.configs}, + } + if isinstance(data.train_weights, dict): + new_weights = { + **data.train_weights, + **{name: 0.0 for name in config.validation_configs if name not in data.train_weights}, + } + else: + # Varying weights case + new_weights = [ + (step_idx, {**weights, **{name: 0.0 for name in config.validation_configs if name not in weights}}) + for step_idx, weights in data.train_weights + ] + data = replace(data, configs=new_configs, 
train_weights=new_weights) + + inner_config = train_lm.TrainLmConfig( + data=data, + trainer=TrainerConfig( + tracker=WandbConfig( + project="marin", + tags=[ + "optimal-training", + f"FLOPs={config.target_budget:.1e}", + f"label={config.label}", + f"N={params:.1e}", + ], + ), + mp=jmp.get_policy("p=f32,c=bfloat16"), + train_batch_size=candidate.batch_size, + per_device_parallelism=per_device_parallelism if per_device_parallelism else -1, + num_train_steps=candidate.train_steps, + steps_per_eval=1000, + checkpointer=CheckpointerConfig( + save_interval=timedelta(minutes=10), + keep=[dict(every=5000)], + ), + mesh=MeshConfig( + compute_mapping={ + "token": (ResourceAxis.REPLICA_DCN, ResourceAxis.REPLICA, ResourceAxis.DATA), + "token_repeat": (ResourceAxis.REPLICA_DCN, ResourceAxis.REPLICA, ResourceAxis.DATA), + } + ), + allow_nondivisible_batch_size=True, + ), + train_seq_len=SEQ_LEN, + model=model_config, + optimizer=candidate.optimizer_config, + ) + + pod_config = TrainLmOnPodConfig( + train_config=inner_config, + resources=ResourceConfig.with_tpu(tpu_type), + output_path=config.output_path, + ) + + logger.info(f"Launching training with resources: {pod_config.resources}") + run_levanter_train_lm(pod_config) + + +# --- Step 1: IsoFLOP Analysis --- +# Creates scaling law fits from the training runs +analysis_step = ExecutorStep( + name=f"{EXPERIMENT_NAME}-analysis", + fn=run_isoflop_analysis_step, + config=IsoFlopAnalysisConfig( + training_runs=[r.as_input_name() for r in nemotron_training], + output_path=this_output_path(), + recipe=MARIN_2025_RECIPE, + ), +) + +# --- Create validation configs --- +# Convert validation TokenizerSteps to LMDatasetSourceConfig at module import time. +# This way instantiate_config resolves InputName paths before run_optimal_training runs. +validation_steps = default_validation_sets(tokenizer=llama3_tokenizer) +validation_configs = { + name: step_to_lm_mixture_component(step, include_raw_paths=False) for name, step in validation_steps.items() +} + +# --- Step 2: Optimal Training Runs --- +# Train compute-optimal models at each target budget +optimal_runs: list[ExecutorStep] = [] +for budget in TARGET_BUDGETS: + step = ExecutorStep( + name=f"{EXPERIMENT_NAME}-optimal-{budget:.0e}", + fn=run_optimal_training, + config=OptimalTrainingConfig( + analysis_output_path=analysis_step.as_input_name(), + target_budget=budget, + label=LABEL, + output_path=this_output_path(), + tokenized=nemotron_mix, + validation_configs=validation_configs, + ), + ) + optimal_runs.append(step) + +# All steps for this experiment +all_steps = [analysis_step, *optimal_runs] + +if __name__ == "__main__": + executor_main(steps=all_steps) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index ca68974aae..961b051e65 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -12,350 +12,674 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Generate ISOFlop sweep steps for varying model sizes on a target datasett. +"""Generate ISOFlop sweep steps for varying model sizes on a target dataset. This script constructs `ExecutorStep` objects that train models of different -sizes while keeping the total training FLOPs roughly constant. It is intended -as a lightweight scaffold for ISOFlop scaling law experiments. +sizes while keeping the total training FLOPs roughly constant. 
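+
+Typical usage, as a sketch (`my_mixture` and the sweep name are placeholders):
+
+    steps, candidates = create_isoflop_sweep_steps(
+        tokenized=my_mixture,
+        experiment_name="my-sweep",
+        recipe=MARIN_2025_RECIPE,
+    )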
""" -import dataclasses +import logging import math import os +import re + +import json +import fsspec + +from collections.abc import Iterator, Sequence from dataclasses import dataclass, replace from levanter.data.text import LMMixtureDatasetConfig from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig +from levanter.models.llama import LlamaConfig from levanter.models.qwen import Qwen3Config from levanter.optim.cautious import CautiousConfig -from levanter.optim.config import OptimizerConfig -from levanter.utils.flop_utils import lm_flops_per_token +from experiments.evals.evals import default_eval +from experiments.evals.task_configs import EvalTaskConfig from experiments.common_pile.tokenize_common_pile import comma_main_mixture from experiments.defaults import default_tokenize, default_train -from experiments.llama import compute_num_parameters, llama3_tokenizer -from experiments.metrics.wandb_related import get_vocab_size_for_tokenizer +from experiments.llama import llama3_tokenizer from experiments.pretraining_datasets.simple import downloads from experiments.simple_train_config import SimpleTrainConfig from experiments.tootsie.exp1295_32b import nemotron_mix from fray.cluster import ResourceConfig from marin.execution.executor import ExecutorStep, InputName, executor_main -from marin.processing.tokenize import lm_mixture_data_config +from marin.processing.tokenize import get_vocab_size_for_tokenizer, lm_mixture_data_config + +from marin.scaling_laws import ( + CandidateConfig, + FitScalingLawsResult, + IsoFlopRecord, + ScalingRecipe, + fit_scaling_laws, + pick_v5p_type, + round_flops_to_bucket, +) +from marin.scaling_laws.eval_metrics_reader import read_eval_records +from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT -DEFAULT_BUDGETS = [1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20] -MLP_RATIO = 4 +logger = logging.getLogger(__name__) -# TPU v5p hardware constants for memory estimation -# Constants for TPU v5p -HBM_PER_CHIP_GIB = 95 -CORES_PER_CHIP = 2 -V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512] # TPU slices +DEFAULT_BUDGETS: tuple[float, ...] = (1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20) +LEGACY_BUDGETS: tuple[float, ...] = (3e18, 9e18, 1.8e19, 3e19, 9e19, 1.8e20, 3e20) +DEFAULT_SEQ_LEN: int = 4096 +DEFAULT_STEPS_PER_RUN: int = 2**16 +# ---------------- Levanter WandB Metric Keys ---------------- +# These keys correspond to the metrics logged by Levanter's training callbacks. +THROUGHPUT_TOKENS_KEY = "throughput/total_tokens" +THROUGHPUT_GFLOPS_KEY = "throughput/total_gflops" +PARAMETER_COUNT_KEY = "parameter_count" +MODEL_CONFIG_KEY = "model" +TRAINER_CONFIG_KEY = "trainer" +DEFAULT_METRIC_KEY = "eval/paloma/c4_en/bpb" -def estimate_bytes( - param_count: int, - hidden_dim: int, - num_layers: int, - batch: int, - seq_len: int, - vocab: int, - optim_mult: int = 3, - dtype_size: int = 4, - fudge_factor: float = 2, -) -> int: - """ - Estimate float32 memory usage (in bytes) for one training step. - Note(Will): I had to do more fudging than expected on this, - but not seems to work ok. 
- - Parameters: - - hidden_dim: model hidden size - - num_layers: number of Transformer layers - - batch, seq_len: training batch size and sequence length - - vocab: vocabulary size - - optim_mult: optimizer memory multiplier (e.g., 100x for Adam + states) - - dtype_size: bytes per float (4 for float32) - - fudge_factor: safety margin for extra memory - Returns: - - total estimated memory in bytes - """ - param_bytes = param_count * optim_mult * dtype_size +# ---------------- Levanter Metrics Transform ---------------- - act_bytes = (batch * seq_len) * ((hidden_dim * num_layers) + vocab * fudge_factor) - total_bytes = param_bytes + act_bytes - return int(total_bytes) * fudge_factor +def parse_isoflop_run_name(run_name: str) -> str | None: + """Parse experiment name from isoflop run name. + Supports two formats: + - New: isoflop-{budget}-N{params}-B{batch}-{experiment_name} + E.g., 'isoflop-1e+18-N1e+08-B128-nemo-wider-depth-adapt' + - Legacy: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} + E.g., 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt' -def pick_v5p_type( - config: Qwen3Config, - hidden: int, - layers: int, - batch: int, - seq_len: int, - vocab: int, -) -> str: + Optionally with a trailing - which is ignored. + + Returns experiment_name or None if parsing fails. """ - Select the smallest TPU v5p slice that fits the model in float32. + # Strip optional - suffix + run_name = re.sub(r"-[0-9a-fA-F]{6}$", "", run_name) + + # New format: isoflop-{budget}-N{params}-B{batch}-{experiment_name} + new_pattern = r"isoflop-(?:[0-9.e+]+)-N(?:[0-9.e+]+)-B(?:\d+)-(.+)" + match = re.match(new_pattern, run_name) + if match: + return match.group(1) + + # Legacy format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} + legacy_pattern = r"isoflop-(?:[0-9.e+]+)-d(?:\d+)-L(?:\d+)-B(?:\d+)-(.+)" + match = re.match(legacy_pattern, run_name) + if match: + return match.group(1) + + return None + + +def transform_levanter_metrics( + raw_records: list[dict], + metric_key: str = DEFAULT_METRIC_KEY, + label_map: dict[str, str] | None = None, + min_flops: float = 1e18, +) -> list[IsoFlopRecord]: + """Transform raw Levanter metrics into IsoFlopRecord list. + + Args: + raw_records: Raw records from read_raw_records(), each containing + 'config', 'summary', and 'run_path' keys. + metric_key: Which metric to use (default: eval/paloma/c4_en/bpb). + label_map: Optional mapping from experiment_name -> display label. + min_flops: Minimum FLOP threshold to include (default: 1e18). Returns: - - TPU slice name, e.g., "v5p-8" or "v5p-32" + List of IsoFlopRecord for records that have all required fields. + Records missing required fields are logged and skipped. 
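+
+    Example:
+        A raw record, schematically (values made up for illustration; records
+        are read by read_eval_records in run_isoflop_analysis_step below):
+
+            {"run_path": "gs://bucket/checkpoints/isoflop/isoflop-1e+18-N2e+08-B256-nemo-wider-depth-adapt",
+             "summary": {"throughput/total_tokens": 8.8e8,
+                         "throughput/total_gflops": 1.0e9,
+                         "parameter_count": 1.9e8,
+                         "eval/paloma/c4_en/bpb": 0.98}}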
""" - param_count = compute_num_parameters(config, vocab) - need_bytes = estimate_bytes(param_count, hidden, layers, batch, seq_len, vocab) - chip_bytes = HBM_PER_CHIP_GIB * 1024**3 - chips = math.ceil(need_bytes / chip_bytes) - cores_req = chips * CORES_PER_CHIP + records = [] - valid = [c for c in V5P_CORE_OPTIONS if c >= cores_req] - if not valid: - raise ValueError(f"Model too large for available v5p slices (need {cores_req} cores).") + for raw in raw_records: + run_path = raw.get("run_path", "") + run_name = os.path.basename(run_path.rstrip("/")) - return f"v5p-{min(valid)}" + summary = raw.get("summary", {}) or {} + # Extract tokens + tokens = summary.get(THROUGHPUT_TOKENS_KEY) + if tokens is None: + logger.warning(f"Missing {THROUGHPUT_TOKENS_KEY} for run {run_name}, skipping") + continue -@dataclass -class IsoFlopSweepConfig: - """Configuration for generating ISOFlop sweep steps.""" + # Extract FLOPs (convert GFLOPs to FLOPs and bucket) + total_gflops = summary.get(THROUGHPUT_GFLOPS_KEY) + if total_gflops is None: + logger.warning(f"Missing {THROUGHPUT_GFLOPS_KEY} for run {run_name}, skipping") + continue + flops = round_flops_to_bucket(total_gflops * 1e9) - tokenized_dataset: InputName | str - tokenizer: str = "stanford-crfm/marin-tokenizer" - budgets: list[float] = dataclasses.field(default_factory=lambda: DEFAULT_BUDGETS) - seq_len: int = 4096 - steps_per_run: int = 2**16 - flop_tolerance: float = 0.01 - base_hidden_layer_ratio: int = 64 - hidden_head_ratio: int = 128 - lr_constant: float = 0.33 - min_hidden_pow: int = 9 - max_hidden_pow: int = 12 - base_optimizer_config: OptimizerConfig = dataclasses.field( - default_factory=lambda: CautiousConfig( - learning_rate=1.0, # Placeholder - weight_decay=0.1, - min_lr_ratio=0.0, - warmup=0.1, - beta1=0.95, - beta2=0.98, - epsilon=1e-15, - max_grad_norm=1, - adamc_weight_decay=True, - lr_schedule="linear", - decay=0.2, - ), - ) - base_train_config: SimpleTrainConfig = dataclasses.field( - default_factory=lambda: SimpleTrainConfig( - resources=ResourceConfig.with_tpu("v5p-8"), - train_batch_size=1, - num_train_steps=50_000, - learning_rate=1.0, # Placeholder - weight_decay=0.1, - min_lr_ratio=0.0, - lr_schedule="linear", - decay=0.2, + if flops < min_flops: + continue + + # Extract metric + metric = summary.get(metric_key) + if metric is None: + logger.warning(f"Missing metric {metric_key} for run {run_name}, skipping") + continue + + # Extract params (required) + params = summary.get(PARAMETER_COUNT_KEY) + if params is None: + logger.warning(f"Missing {PARAMETER_COUNT_KEY} for run {run_name}, skipping") + continue + + # Determine label from run name + exp_name = parse_isoflop_run_name(run_name) or run_name + if label_map and exp_name in label_map: + label = label_map[exp_name] + else: + label = exp_name + + records.append( + IsoFlopRecord( + tokens=float(tokens), + metric=float(metric), + flops=float(flops), + params=float(params), + label=label, + ) ) - ) + logger.info(f"Transformed {len(records)} records from {len(raw_records)} raw records") + return records -def round_to_power_of_two(x: float) -> int: - """Round ``x`` to the nearest power of two.""" +def _round_to_power_of_two(x: float) -> int: + """Round x UP to the nearest power of 2.""" if x <= 1: return 1 return 2 ** math.ceil(math.log2(x)) -def compute_total_flops( - batch: int, +def _format_run_name( + budget: float, + hidden_size: int, num_layers: int, - hidden: int, - intermediate: int, - num_kv_heads: int, - num_heads: int, - steps: int, - seq_len: int, - vocab_size: int, 
-) -> float: - """Compute total training FLOPs using Levanter utilities.""" - - flops_per_token = lm_flops_per_token( - hidden, - intermediate, - num_layers, - num_kv_heads, - num_heads, - seq_len, - vocab_size, - glu=True, - ) - return flops_per_token * batch * steps * seq_len + batch_size: int, + experiment_name: str, +) -> str: + """Format run name using architecture details (hidden size and layers). + + Format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} + """ + return f"isoflop-{budget:.0e}-d{hidden_size}-L{num_layers}-B{batch_size}-{experiment_name}" + + +@dataclass(frozen=True) +class Marin2025Recipe: + """Marin 2025 scaling recipe with all hyperparameters and formulas. + This recipe implements all the Marin-specific decisions for scaling experiments. + The vocab_size is derived from the tokenizer, making the recipe self-contained + for all model configuration decisions. + """ -def candidate_configs(cfg: IsoFlopSweepConfig, budget: float): - """Yield candidate model configurations within the FLOP budget.""" + name: str = "marin-2025" + tokenizer: str = "stanford-crfm/marin-tokenizer" + """Tokenizer to use. vocab_size is derived from this.""" - vocab_size = get_vocab_size_for_tokenizer(cfg.tokenizer) + @property + def vocab_size(self) -> int: + """Vocabulary size derived from the tokenizer.""" + return get_vocab_size_for_tokenizer(self.tokenizer) - if budget > 9e18: - step_size = 256 - else: - step_size = 128 + # --- Learning rate scaling --- + # lr = lr_constant * sqrt(batch_size) / hidden_dim + lr_constant: float = 0.33 - for hidden_size in range(2**cfg.min_hidden_pow, (2**cfg.max_hidden_pow) + 1, step_size): + # --- Beta2 scaling for Adam --- + # beta2 = beta2_base ** (batch_size / beta2_batch_divisor) + beta2_base: float = 0.98 + beta2_batch_divisor: float = 128 + + # --- Optimizer hyperparameters --- + weight_decay: float = 0.1 + min_lr_ratio: float = 0.0 + warmup: float = 0.1 + beta1: float = 0.95 + epsilon: float = 1e-15 + max_grad_norm: float = 1.0 + lr_schedule: str = "linear" + decay: float = 0.2 + + # --- Architecture ratios --- + mlp_ratio: int = 4 + hidden_head_ratio: int = 128 + + # --- Architecture formula for depth-to-width scaling --- + base_hidden_layer_ratio: int = 64 + layer_scaling_factor: float = 4.0 + layer_formula_offset: int = 9 + + # --- Constraints --- + max_learning_rate: float = 0.01 + min_batch_size: int = 8 + max_batch_size: int = 8192 + # max_params scales with sqrt(budget) above 3e20, with floor of 12B and ceiling of 1T + base_max_params: float = 12e9 + base_max_params_budget: float = 3e20 + global_max_params: float = 1e12 + + # --- Search step sizes for isoflop sweeps --- + small_budget_step_size: int = 128 + large_budget_step_size: int = 256 + budget_step_threshold: float = 9e18 + + def _compute_learning_rate(self, batch_size: int, hidden_dim: int) -> float: + """Compute learning rate from batch size and hidden dim.""" + return (self.lr_constant * math.sqrt(batch_size)) / hidden_dim + + def _compute_beta2(self, batch_size: int) -> float: + """Compute beta2 from batch size.""" + return self.beta2_base ** (batch_size / self.beta2_batch_divisor) + + def compute_num_layers(self, hidden_size: int) -> int: + """Compute number of layers from hidden size using the depth-width formula.""" hs_pow = math.log2(hidden_size) - intermediate_dim = hidden_size * MLP_RATIO - num_layers = round(hidden_size / (cfg.base_hidden_layer_ratio + (hs_pow * 4) - cfg.min_hidden_pow)) - n_heads = max(1, hidden_size // cfg.hidden_head_ratio) - n_kv_heads = 
n_heads - - batch_exact = budget / compute_total_flops( - 1, - num_layers, - hidden_size, - intermediate_dim, - n_kv_heads, - n_heads, - cfg.steps_per_run, - cfg.seq_len, - vocab_size, + return round( + hidden_size + / (self.base_hidden_layer_ratio + (hs_pow * self.layer_scaling_factor) - self.layer_formula_offset) ) - batch_size = round_to_power_of_two(batch_exact) - lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size - while lr > 0.01: - batch_size //= 2 - lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size - b2 = 0.98 ** (batch_size / 128) # https://arxiv.org/pdf/2507.07101 - - if batch_size < 8: - continue + def _get_step_size(self, budget: float) -> int: + """Get hidden_size search step size based on budget.""" + if budget > self.budget_step_threshold: + return self.large_budget_step_size + return self.small_budget_step_size + + def _max_params_for_budget(self, budget: float) -> float: + """Compute max_params as a function of budget. + + Returns base_max_params for budgets <= base_max_params_budget, + then scales with sqrt(budget) for larger budgets, capped at global_max_params. + """ + scaling = self.base_max_params * math.sqrt(budget / self.base_max_params_budget) + return min(max(self.base_max_params, scaling), self.global_max_params) + + def _build_model_config_from_hidden_size(self, hidden_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: + """Build model config from hidden_size directly.""" + if hidden_size % self.hidden_head_ratio != 0: + raise ValueError( + f"hidden_size ({hidden_size}) must be divisible by hidden_head_ratio ({self.hidden_head_ratio}). " + f"Got remainder {hidden_size % self.hidden_head_ratio}." + ) + num_layers = self.compute_num_layers(hidden_size) + intermediate_dim = hidden_size * self.mlp_ratio + n_heads = max(1, hidden_size // self.hidden_head_ratio) + + return Qwen3Config( + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + num_heads=n_heads, + num_kv_heads=n_heads, + max_seq_len=seq_len, + rope=Llama3RotaryEmbeddingsConfig(), + ) - steps_exact = budget / compute_total_flops( - batch_size, - num_layers, - hidden_size, - intermediate_dim, - n_kv_heads, - n_heads, - 1, - cfg.seq_len, - vocab_size, + def estimate_memory_bytes( + self, + candidate: CandidateConfig, + optim_mult: int = 3, + dtype_size: int = 4, + fudge_factor: float = 2.0, + ) -> int: + """Estimate memory usage in bytes for training. 
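+
+        Roughly (see the inline comments for the per-term derivation):
+            total ≈ fudge_factor * (param_count * optim_mult * dtype_size + act_bytes + vocab_size * hidden * 2)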
+
+        Accounts for:
+        - Parameters + optimizer state (master weights, momentum, variance)
+        - Activation memory for hidden, attention, and MLP tensors (linear in seq with flash attention)
+        - Embedding table memory
+        """
+        model_config = candidate.model_config
+        batch_size = candidate.batch_size
+        seq_len = model_config.max_seq_len
+        hidden = model_config.hidden_dim
+        intermediate = getattr(model_config, "intermediate_dim", hidden * self.mlp_ratio)
+        layers = model_config.num_layers
+        # Parameters + optimizer (master weights + momentum + variance in fp32)
+        param_count = model_config.total_trainable_params(self.vocab_size)
+        param_bytes = param_count * optim_mult * dtype_size
+
+        # Activation memory per layer (bf16 = 2 bytes)
+        # - Hidden states: batch * seq * hidden
+        # - Attention Q/K/V/O: batch * seq * hidden * 4 (flash attention is O(seq), not O(seq²))
+        # - MLP intermediate: batch * seq * intermediate
+        hidden_act = batch_size * seq_len * hidden * 2
+        attn_act = batch_size * seq_len * hidden * 4 * 2  # Q, K, V, output tensors (flash attn)
+        mlp_act = batch_size * seq_len * intermediate * 2
+        per_layer_act = hidden_act + attn_act + mlp_act
+
+        # Activation memory scales with layers. Even with gradient checkpointing,
+        # we need significant memory for recomputation and gradient storage.
+        # Empirically, memory scales roughly as layers * 0.75 for large models.
+        act_bytes = per_layer_act * max(layers * 3 // 4, 4)
+
+        # Embedding table (often not sharded well)
+        embed_bytes = self.vocab_size * hidden * 2
+
+        total_bytes = param_bytes + act_bytes + embed_bytes
+        return int(total_bytes * fudge_factor)
+
+    def build_model_configs(
+        self,
+        budget: float,
+        seq_len: int = DEFAULT_SEQ_LEN,
+    ) -> Iterator[LlamaConfig]:
+        """Yield candidate model architectures for the given FLOP budget.
+
+        Uses wide bounds (2**9 to 2**17) and relies on batch size filtering
+        in build_candidate_config to select valid configurations.
+        """
+        step_size = self._get_step_size(budget)
+
+        for hidden_size in range(2**9, 2**17, step_size):
+            yield self._build_model_config_from_hidden_size(hidden_size, seq_len)
+
+    def build_candidate_config(
+        self,
+        model_config: LlamaConfig,
+        tokens: float,
+        flops_budget: float,
+        seq_len: int = DEFAULT_SEQ_LEN,
+    ) -> CandidateConfig | None:
+        """Build complete training config for a model and token count.
+
+        Returns None if the configuration is invalid (e.g., batch_size < minimum
+        after learning rate constraints are applied).
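+
+        A worked example under the default constants (illustrative): with
+        hidden_size=1024, tokens=2.7e10 and seq_len=4096, the initial batch is
+        2.7e10 / (65536 * 4096) ≈ 100.6, rounded up to 128; then
+        lr = 0.33 * sqrt(128) / 1024 ≈ 3.6e-3 (below max_learning_rate),
+        train_steps = round(2.7e10 / (128 * 4096)) = 51498, and
+        beta2 = 0.98 ** (128 / 128) = 0.98.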
+ """ + hidden_size = model_config.hidden_dim + + # Start with batch_size that gives us ~DEFAULT_STEPS_PER_RUN steps + target_steps = DEFAULT_STEPS_PER_RUN + batch_exact = tokens / (target_steps * seq_len) + batch_size = _round_to_power_of_two(batch_exact) + + # Adjust batch_size to respect learning rate constraints + lr = self._compute_learning_rate(batch_size, hidden_size) + while lr > self.max_learning_rate: + batch_size //= 2 + lr = self._compute_learning_rate(batch_size, hidden_size) + + # Return None if batch_size is outside valid range + if batch_size < self.min_batch_size or batch_size > self.max_batch_size: + return None + + # Compute train_steps to achieve target tokens + train_steps = round(tokens / (batch_size * seq_len)) + + # Compute actual tokens after rounding + actual_tokens = batch_size * train_steps * seq_len + + # Build optimizer config + beta2 = self._compute_beta2(batch_size) + optimizer_config = CautiousConfig( + learning_rate=lr, + weight_decay=self.weight_decay, + min_lr_ratio=self.min_lr_ratio, + warmup=self.warmup, + beta1=self.beta1, + beta2=beta2, + epsilon=self.epsilon, + max_grad_norm=self.max_grad_norm, + adamc_weight_decay=True, + lr_schedule=self.lr_schedule, + decay=self.decay, ) - train_steps = round(steps_exact) - - achieved_flops = compute_total_flops( - batch_size, - num_layers, - hidden_size, - intermediate_dim, - n_kv_heads, - n_heads, - train_steps, - cfg.seq_len, - vocab_size, + + return CandidateConfig( + model_config=model_config, + optimizer_config=optimizer_config, + batch_size=batch_size, + train_steps=train_steps, + tokens=actual_tokens, + flops_budget=flops_budget, ) - if abs(achieved_flops - budget) / budget > cfg.flop_tolerance: - continue + def candidates_for_budget( + self, + budget: float, + seq_len: int = DEFAULT_SEQ_LEN, + ) -> Iterator[CandidateConfig]: + """Yield valid candidate training configs for the given FLOP budget.""" + max_params = self._max_params_for_budget(budget) + for model_config in self.build_model_configs(budget, seq_len): + # Skip models that exceed budget-dependent max_params + params = model_config.total_trainable_params(self.vocab_size) + if params > max_params: + continue + flops_per_token = model_config.flops_per_token(self.vocab_size, seq_len) + tokens = budget / (3 * flops_per_token) + candidate = self.build_candidate_config(model_config, tokens, budget, seq_len) + if candidate is not None: + yield candidate - yield (hidden_size, intermediate_dim, num_layers, n_heads, n_kv_heads, batch_size, train_steps, lr, b2) - - -def generate_isoflop_steps(config: IsoFlopSweepConfig, experiment_name: str) -> list[ExecutorStep]: - """Generate executor steps for an ISOFlop sweep.""" - - steps: list[ExecutorStep] = [] - metadata = [] - vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) - - for budget in config.budgets: - for ( - hidden_size, - intermediate_dim, - num_layers, - n_heads, - n_kv_heads, - batch_size, - train_steps, - lr, - b2, - ) in candidate_configs(config, budget): - model_cfg = Qwen3Config( - max_seq_len=config.seq_len, - hidden_dim=hidden_size, - intermediate_dim=intermediate_dim, - num_heads=n_heads, - num_kv_heads=n_kv_heads, - num_layers=num_layers, - rope=Llama3RotaryEmbeddingsConfig(), - ) - tpu_type = pick_v5p_type( - config=model_cfg, - hidden=hidden_size, - layers=num_layers, - batch=batch_size, - seq_len=config.seq_len, - vocab=vocab_size, - ) - optimizer_cfg = replace(config.base_optimizer_config, learning_rate=lr, beta2=b2) - train_cfg = replace( - config.base_train_config, - 
train_batch_size=batch_size, - learning_rate=lr, - num_train_steps=train_steps, - resources=ResourceConfig.with_tpu(tpu_type), - optimizer_config=optimizer_cfg, - ) - run_name = f"isoflop-{budget:.0e}-d{hidden_size}-L{num_layers}-B{batch_size}-{experiment_name}" - step = default_train( - name=run_name, - tokenized=config.tokenized_dataset, - model_config=model_cfg, - train_config=train_cfg, - eval_harness_tasks=[], - tags=( - f"FLOPs={budget:.1e}", - f"d={hidden_size}", - f"L={num_layers}", - f"B={batch_size}", - f"steps={train_steps}", - f"tpu={tpu_type}", - ), - ) - metadata.append((budget, hidden_size, num_layers, batch_size, train_steps)) - # Reuse checkpoints by pinning every sweep run to a deterministic directory. - static_output_path = os.path.join( - "checkpoints", - "isoflop", - run_name, - ) - steps.append(step.with_output_path(static_output_path)) +MARIN_2025_RECIPE = Marin2025Recipe() +"""Default Marin scaling recipe.""" + + +# ---------------- IsoFlop Analysis ---------------- + + +@dataclass(frozen=True, kw_only=True) +class IsoFlopAnalysisConfig: + """Configuration for IsoFLOP scaling law analysis. + + The training_runs field creates blocking dependencies on the training jobs. + This config is for use with ExecutorStep. + """ - return steps, metadata + training_runs: Sequence[str] + """Training run output paths (executor resolves InputName to str at runtime).""" + output_path: str + """Where to write analysis outputs.""" + + recipe: ScalingRecipe + """Scaling recipe for computing optimal hyperparameters.""" + + metric_key: str = DEFAULT_METRIC_KEY + """Metric to use for loss (default: eval/paloma/c4_en/bpb).""" + + label_map: tuple[tuple[str, str], ...] | None = None + """Optional mapping from experiment_name -> display label as tuple of pairs.""" + + metrics_filename: str = "tracker_metrics.jsonl" + """Name of the metrics file within each checkpoint directory.""" + + wandb_entity_project: str = f"{WANDB_ENTITY}/{WANDB_PROJECT}" + """WandB entity/project to query for backfill (format: 'entity/project').""" + + +def run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> FitScalingLawsResult: + """Execute IsoFLOP scaling law analysis. + + This is the experiment step function that: + 1. Reads raw metrics from training runs + 2. Transforms them using Levanter schema knowledge + 3. Runs the scaling law analysis + 4. 
Saves results to output_path + + Args: + config: Analysis config with training_runs and analysis settings + + Returns: + FitScalingLawsResult with fitted scaling laws + """ + # Read raw records from training runs + raw_records = read_eval_records( + training_runs=config.training_runs, + metrics_filename=config.metrics_filename, + wandb_entity_project=config.wandb_entity_project, + ) -def generate_isoflop_sweep( - tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, + if not raw_records: + logger.warning("No eval metrics found in training runs") + return FitScalingLawsResult(minima_records=[], scaling_fits={}, fit_curves={}) + + # Transform to typed records using Levanter schema knowledge + label_map = dict(config.label_map) if config.label_map else None + records = transform_levanter_metrics(raw_records, config.metric_key, label_map) + + if not records: + logger.warning("No valid isoflop data after transformation") + return FitScalingLawsResult(minima_records=[], scaling_fits={}, fit_curves={}) + + logger.info(f"Loaded {len(records)} runs for scaling law analysis") + labels = list(dict.fromkeys(r.label for r in records)) + flops_budgets = sorted(set(r.flops for r in records)) + logger.info(f"Labels found: {labels}") + logger.info(f"FLOP budgets: {flops_budgets}") + + # Run scaling law analysis + result = fit_scaling_laws(records) + + logger.info(f"Found {len(result.minima_records)} optimal configurations") + for label, scaling_fit in result.scaling_fits.items(): + logger.info(f" {label}: D* = {scaling_fit.A:.2e} * C^{scaling_fit.alpha:.3f}") + + # Save results + fs, _, _ = fsspec.get_fs_token_paths(config.output_path) + fs.makedirs(config.output_path, exist_ok=True) + + result_path = os.path.join(config.output_path, "isoflop_analysis_result.json") + result_dict = { + "minima_records": [ + { + "label": r.label, + "flops": r.flops, + "optimal_tokens": r.optimal_tokens, + "loss_at_optimal": r.loss_at_optimal, + "optimal_params": r.optimal_params, + "scaling_alpha": r.scaling_alpha, + "scaling_A": r.scaling_A, + } + for r in result.minima_records + ], + "scaling_fits": {k: list(v) for k, v in result.scaling_fits.items()}, + } + with fs.open(result_path, "w") as f: + json.dump(result_dict, f, indent=2) + logger.info(f"Saved results to {result_path}") + + # Save fit curves for downstream plotting + fit_curves_path = os.path.join(config.output_path, "fit_curves.json") + fit_curves_json = {f"{label}|{flops}": list(coeffs) for (label, flops), coeffs in result.fit_curves.items()} + with fs.open(fit_curves_path, "w") as f: + json.dump(fit_curves_json, f, indent=2) + logger.info(f"Saved fit curves to {fit_curves_path}") + + return result + + +def create_isoflop_sweep_steps( + tokenized: InputName | str | LMMixtureDatasetConfig, experiment_name: str, - **kwargs, -) -> list[ExecutorStep]: - sweep_cfg = IsoFlopSweepConfig(tokenized_dataset=tokenized, **kwargs) - steps, metadata = generate_isoflop_steps(sweep_cfg, experiment_name) + recipe: ScalingRecipe, + budgets: tuple[float, ...] = DEFAULT_BUDGETS, + eval_tasks: tuple[EvalTaskConfig, ...] | None = None, + seq_len: int = 4096, +) -> tuple[list[ExecutorStep], list[CandidateConfig]]: + """Create ExecutorSteps for an ISOFlop sweep. + + Args: + tokenized: Tokenized dataset to train on. + experiment_name: Name suffix for the experiment (e.g., 'nemo', 'dclm'). + recipe: ScalingRecipe with hyperparameters (includes vocab_size). + budgets: FLOP budgets to sweep over. + eval_tasks: Optional evaluation tasks to run after training. 
+ seq_len: Sequence length for training. + + Returns: + A tuple of: + - steps: Training and evaluation ExecutorSteps for the sweep. + - candidates: CandidateConfig for each training run with full config details. + """ + candidates = [c for budget in budgets for c in recipe.candidates_for_budget(budget, seq_len)] + + # Base config for training runs (values overridden per-candidate) + base_train_config = SimpleTrainConfig( + resources=ResourceConfig.with_tpu("v5p-8"), + train_batch_size=1, + num_train_steps=50_000, + learning_rate=1.0, # Overridden via optimizer_config + ) - return steps, metadata + train_steps: list[ExecutorStep] = [] + eval_steps: list[ExecutorStep] = [] + + # Create ExecutorSteps for each candidate configuration + for candidate in candidates: + model_config = candidate.model_config + estimated_memory = recipe.estimate_memory_bytes(candidate) + tpu_type = pick_v5p_type(estimated_memory) + + # Use local naming with architecture details for backward compatibility + run_name = _format_run_name( + candidate.flops_budget, + model_config.hidden_dim, + model_config.num_layers, + candidate.batch_size, + experiment_name, + ) + output_path = f"checkpoints/isoflop/{run_name}" + + # Build tags for tracking + params = model_config.total_trainable_params(recipe.vocab_size) + tags = ( + f"FLOPs={candidate.flops_budget:.1e}", + f"N={params:.1e}", + f"B={candidate.batch_size}", + f"steps={candidate.train_steps}", + f"tokens={candidate.tokens:.1e}", + ) + train_cfg = replace( + base_train_config, + train_batch_size=candidate.batch_size, + learning_rate=candidate.optimizer_config.learning_rate, + num_train_steps=candidate.train_steps, + resources=ResourceConfig.with_tpu(tpu_type), + optimizer_config=candidate.optimizer_config, + ) -dclm_tokenized = dataclasses.replace( - default_tokenize( - name="dclm_baseline", - dataset=downloads["dclm_baseline"], - tokenizer=llama3_tokenizer, - ).with_output_path("tokenized/dclm_baseline-0206f1/"), -) + # Create training step + train_step = default_train( + name=run_name, + tokenized=tokenized, + model_config=model_config, + train_config=train_cfg, + eval_harness_tasks=[], + tags=tags, + ) + # Pin to static output path for checkpoint reuse + train_step = train_step.with_output_path(output_path) + train_steps.append(train_step) + + # Create evaluation step if eval tasks specified + if eval_tasks: + eval_step = default_eval( + train_step, + resource_config=train_cfg.resources, + evals=eval_tasks, + ) + eval_steps.append(eval_step) + + all_steps: list[ExecutorStep] = [*train_steps, *eval_steps] + return all_steps, candidates + + +# --- Tokenized Datasets --- + +dclm_tokenized = default_tokenize( + name="dclm_baseline", + dataset=downloads["dclm_baseline"], + tokenizer=llama3_tokenizer, +).with_output_path("tokenized/dclm_baseline-0206f1/") dclm_mix = lm_mixture_data_config( components={"dclm": dclm_tokenized}, @@ -363,13 +687,11 @@ def generate_isoflop_sweep( num_validation_sequences={"dclm": 1024}, ) -dolma3_mix_tokenized = dataclasses.replace( - default_tokenize( - name="dolma3_mix-150B-1025", - dataset=downloads["dolma3_mix_150b_1025"], - tokenizer=llama3_tokenizer, - ).with_output_path("tokenized/dolma3_mix-150B-1025-15d04ee/"), -) +dolma3_mix_tokenized = default_tokenize( + name="dolma3_mix-150B-1025", + dataset=downloads["dolma3_mix_150b_1025"], + tokenizer=llama3_tokenizer, +).with_output_path("tokenized/dolma3_mix-150B-1025-15d04ee/") dolma3_mix = lm_mixture_data_config( components={"dolma3_mix-150B-1025": dolma3_mix_tokenized}, @@ -377,14 +699,38 @@ 
def generate_isoflop_sweep( num_validation_sequences={"dolma3_mix-150B-1025": 1024}, ) + MARIN_SCALING_SUITES = { - "nemotron": generate_isoflop_sweep(nemotron_mix, experiment_name="nemo-wider-depth-adapt"), - "common_pile": generate_isoflop_sweep(comma_main_mixture(permutation_type="linear"), experiment_name="comma-mix"), - "common_pile_feistel": generate_isoflop_sweep( - comma_main_mixture(permutation_type="feistel"), experiment_name="comma-mix-feistel" + "nemotron": create_isoflop_sweep_steps( + tokenized=nemotron_mix, + experiment_name="nemo-wider-depth-adapt", + recipe=MARIN_2025_RECIPE, + budgets=LEGACY_BUDGETS, + ), + "common_pile": create_isoflop_sweep_steps( + tokenized=comma_main_mixture(permutation_type="linear"), + experiment_name="comma-mix", + recipe=MARIN_2025_RECIPE, + budgets=LEGACY_BUDGETS, + ), + "common_pile_feistel": create_isoflop_sweep_steps( + tokenized=comma_main_mixture(permutation_type="feistel"), + experiment_name="comma-mix-feistel", + recipe=MARIN_2025_RECIPE, + budgets=LEGACY_BUDGETS, + ), + "dclm-default": create_isoflop_sweep_steps( + tokenized=dclm_mix, + experiment_name="dclm-default", + recipe=MARIN_2025_RECIPE, + budgets=LEGACY_BUDGETS, + ), + "dolma3_mix_150b": create_isoflop_sweep_steps( + tokenized=dolma3_mix, + experiment_name="dolma3-mix-150b-1025", + recipe=MARIN_2025_RECIPE, + budgets=LEGACY_BUDGETS, ), - "dclm-default": generate_isoflop_sweep(dclm_mix, experiment_name="dclm-default"), - "dolma3_mix_150b": generate_isoflop_sweep(dolma3_mix, experiment_name="dolma3-mix-150b-1025"), } if __name__ == "__main__": diff --git a/experiments/metrics/wandb_related.py b/experiments/metrics/wandb_related.py index b74d902fe4..2c478bbf15 100644 --- a/experiments/metrics/wandb_related.py +++ b/experiments/metrics/wandb_related.py @@ -18,6 +18,7 @@ from typing import Any import wandb +from marin.processing.tokenize import get_vocab_size_for_tokenizer from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT logger = logging.getLogger(__name__) @@ -145,26 +146,6 @@ def get_all_runs_over_period( return None -def get_vocab_size_for_tokenizer(tokenizer: str) -> int | None: - logger.info(f"Tokenizer:{tokenizer}") - if tokenizer == "EleutherAI/gpt-neox-20b": - vocab_size = 50_257 - elif tokenizer == "meta-llama/Meta-Llama-3.1-8B": - vocab_size = 128_256 - elif tokenizer == "stanford-crfm/marin-tokenizer": - vocab_size = 128_256 - elif tokenizer == "meta-llama/Llama-2-7b": - vocab_size = 32_000 - elif tokenizer == "gpt2": - vocab_size = 50_257 - else: - logger.error(f"Unknown tokenizer: {tokenizer}") - return None - - logger.info(f"Vocab size: {vocab_size}") - return vocab_size - - def count_params_for_run(run_id: str, entity=WANDB_ENTITY, project=WANDB_PROJECT) -> int | None: """ Retrieves the number of parameters for a specific WandB run. diff --git a/experiments/tootsie/exp654_scaling_tootsie.py b/experiments/tootsie/exp654_scaling_tootsie.py deleted file mode 100644 index d5393270a6..0000000000 --- a/experiments/tootsie/exp654_scaling_tootsie.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from experiments.pretraining_datasets.dclm import dclm_mixture_config_llama3_wrong -from marin.execution.executor import executor_main -import dataclasses -import logging -from collections.abc import Sequence - -from levanter.data.text import LMMixtureDatasetConfig -from levanter.models.llama import LlamaConfig - -from experiments.defaults import default_train -from experiments.llama import llama_1_4b -from experiments.simple_train_config import SimpleTrainConfig -from fray.cluster import ResourceConfig -from marin.execution.executor import ExecutorStep, InputName - -DEFAULT_MODEL_CONFIG = LlamaConfig( - max_seq_len=4096, - hidden_dim=2048, - intermediate_dim=7168, - num_heads=16, - num_kv_heads=8, - num_layers=16, -) - -# WSD-S training configuration -DEFAULT_SWEEP_TRAIN_CONFIG = SimpleTrainConfig( - resources=ResourceConfig.with_tpu("v4-128"), - train_batch_size=1024, - learning_rate=1e-3, # will be replaced in the scaling law suite - weight_decay=0.1, - # https://arxiv.org/pdf/2412.04403 gets 4 points per run. this gives us 5 - num_train_steps=50000, # 4096 * 1024 * 50000 = ~200B tokens - cycle_length=10000, # 5 cycles with 10000 steps/cycle - steps_per_eval=10000, # same as cycle length - warmup=1000, # initial warmup - decay=0.1, # 10% decay - lr_schedule="inv", # inv decay -) - - -# TODO(dlwh): in an old levanter branch (wandb_sweeps) i had fancier sweep generation stuff for doing surgery on the -# config. Consider using that. - - -def scaling_law_suite( - sweep_name: str, - tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, - widths: Sequence[int] = (512, 768, 1024, 1536, 2048), - base_model_config: LlamaConfig = llama_1_4b, - tags: Sequence[str] = (), - *, - intermediate_scale: float = 8, - training_config: SimpleTrainConfig = DEFAULT_SWEEP_TRAIN_CONFIG, - base_lr: float = 3e-4 * 4096, - max_lr: float = 5e-3, -) -> Sequence[ExecutorStep]: - """ - Provides width-wise scaling suite using WSD-S (or other) training configurations. - - Assumptions (consistent with llama 3): - * 128 head_dim - * 8 key-value heads unless that doesn't work with head_dim = 128 - * intermediate_dim = _round_to_multiple(intermediate_scale * width, 128) - * all widths are divisible by 128 - * peak lr is scaled to be base_lr / width, but clamped to max_lr - - Args: - sweep_name: prefix for the sweep name. runs will be named {sweep_name}-{width}-{hash} - base_model_config: base model configuration. Sweep will be generated by varying the width. - tokenized: input data for training - widths: range of widths to sweep over - training_config: training configuration - - References: - * default widths are from https://arxiv.org/pdf/2412.04403 table 1 (plus 512) - * incredibly wide intermediate_scale is based on the same table - * base_lr is based on llama 3 (https://arxiv.org/pdf/2407.21783 table 3) - * max_lr is a reasonable value that is not too high - * default model config (1_4b) gives the number of layers used in https://arxiv.org/pdf/2412.04403 table 1 - * lr scaling is based on µP/µTransfer: https://arxiv.org/pdf/2203.03466 where generally speaking, lr should - be scaled down by the width of the model. 
- """ - - steps = [] - for w in widths: - intermediate_dim = _round_to_multiple(intermediate_scale * w, 128) - head_size = 128 # keeping this 128 means we can use splash attention - num_heads = w // head_size - num_kv_heads = min(num_heads, 8) - assert num_heads * head_size == w, f"Number of heads must divide width: {w} % {head_size} != 0" - - # if num_kv_heads doesn't divide num_heads, we need to adjust num_kv_heads - if num_heads % num_kv_heads != 0: - num_kv_heads = num_heads - - model_config = dataclasses.replace( - base_model_config, - hidden_dim=w, - intermediate_dim=intermediate_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - ) - - lr = min(base_lr / w, max_lr) - training_config = dataclasses.replace(training_config, learning_rate=lr) - - logging.info(f"Creating training step for {sweep_name}-{w} with width {w} and lr {lr}") - - steps.append( - default_train( - name=f"{sweep_name}-{w}", - tokenized=tokenized, - model_config=model_config, - train_config=training_config, - tags=tags, - ) - ) - return steps - - -def _round_to_multiple(x, multiple): - return int(multiple * round(x / multiple)) - - -TAG = ["654_scaling_tootsie"] - -suite = scaling_law_suite(sweep_name="tootsie-scaling", tokenized=dclm_mixture_config_llama3_wrong, tags=TAG) - -if __name__ == "__main__": - executor_main( - steps=[ - *suite, - ], - description="scaling law suite to predict performance of 8B model on DCLM mix", - ) diff --git a/lib/levanter/src/levanter/eval.py b/lib/levanter/src/levanter/eval.py index 6d7b77e6ca..e8191d3abc 100644 --- a/lib/levanter/src/levanter/eval.py +++ b/lib/levanter/src/levanter/eval.py @@ -3,12 +3,15 @@ import asyncio import dataclasses +import json import logging +import os import warnings from collections import defaultdict from typing import Callable, Mapping, Optional, Sequence, TypeVar import equinox as eqx +import fsspec import jax.numpy as jnp import jmp import numpy as np @@ -173,6 +176,7 @@ def cb_tagged_lm_evaluate( eval_ema: bool = True, prefix: str = "eval", mp: jmp.Policy = None, + checkpoint_path: Optional[str] = None, ) -> Callable[[StepInfo], None]: """ Evaluates multiple tagged datasets using a given evaluation function. 
@@ -196,6 +200,7 @@ def cb_tagged_lm_evaluate( prefix: The prefix to use for logging the losses eval_current: Whether to evaluate the model's current parameters eval_ema: Whether to evaluate the EMA model (or other model averaged model) + checkpoint_path: If provided, write eval metrics to a JSONL file in this directory """ evaluator = TaggedEvaluator( @@ -207,10 +212,12 @@ def cb_tagged_lm_evaluate( def eval_callback(step: StepInfo): step_count = step.step + metrics_to_write = {} if eval_current: log_dict = eval_model(evaluator, step.model, prefix=prefix) levanter.tracker.log(log_dict, step=step_count) + metrics_to_write.update(log_dict) if not eval_current and step.state.model_averaging is None: raise ValueError("Cannot evaluate EMA model without model averaging, but you only want to evaluate EMA") @@ -218,6 +225,21 @@ def eval_callback(step: StepInfo): if eval_ema and step.state.model_averaging is not None: log_dict = eval_model(evaluator, step.eval_model, prefix=_join_prefix(prefix, "ema")) levanter.tracker.log(log_dict, step=step_count) + metrics_to_write.update(log_dict) + + # Write metrics to file if checkpoint_path is provided + if checkpoint_path is not None and metrics_to_write: + metrics_file = os.path.join(checkpoint_path, "eval_metrics.jsonl") + fs, _, _ = fsspec.get_fs_token_paths(metrics_file) + fs.makedirs(checkpoint_path, exist_ok=True) + with fs.open(metrics_file, "a") as f: + # Convert numpy/jax floats to Python floats for JSON serialization + serializable_metrics = { + k: float(v) if isinstance(v, (np.floating, jnp.floating)) else v + for k, v in metrics_to_write.items() + } + record = {"step": int(step_count), **serializable_metrics} + f.write(json.dumps(record, sort_keys=True) + "\n") return diff --git a/lib/levanter/src/levanter/main/train_lm.py b/lib/levanter/src/levanter/main/train_lm.py index 68f3228e7b..aa45fca792 100644 --- a/lib/levanter/src/levanter/main/train_lm.py +++ b/lib/levanter/src/levanter/main/train_lm.py @@ -199,6 +199,11 @@ def loss_function(model: LmHeadModel, example: LmExample, *, key=None): if len(tagged_eval_datasets) == 0: logger.warning("No evaluation datasets provided.") else: + # Write eval metrics to the same directory as checkpoints + checkpoint_path = None + if config.trainer.checkpointer is not None: + checkpoint_path = config.trainer.checkpointer.expanded_path(trainer.run_id) + cb = levanter.eval.cb_tagged_lm_evaluate( EvalBatch, tagged_eval_datasets, @@ -207,6 +212,7 @@ def loss_function(model: LmHeadModel, example: LmExample, *, key=None): compute_axis_mapping, max_eval_examples_per_ds, mp=config.trainer.mp, + checkpoint_path=checkpoint_path, ) trainer.add_hook(cb, every=config.trainer.steps_per_eval) diff --git a/lib/levanter/src/levanter/tracker/wandb.py b/lib/levanter/src/levanter/tracker/wandb.py index 2063f035ed..81333b0ca2 100644 --- a/lib/levanter/src/levanter/tracker/wandb.py +++ b/lib/levanter/src/levanter/tracker/wandb.py @@ -36,7 +36,7 @@ class WandbTracker(Tracker): name: str = "wandb" run: WandbRun - def __init__(self, run: Optional[WandbRun]): + def __init__(self, run: Optional[WandbRun], replicate_path: Optional[str] = None): import wandb if run is None: @@ -52,6 +52,7 @@ def __init__(self, run: Optional[WandbRun]): self.run = run self._last_warning_step = -500 + self._replicate_path = replicate_path def log_hyperparameters(self, hparams: dict[str, Any]): self.run.config.update(_convert_value_to_loggable_rec(hparams), allow_val_change=True) @@ -100,8 +101,28 @@ def log_artifact(self, artifact_path, *, name: 
Optional[str] = None, type: Optio def finish(self): logger.info("Finishing wandb run...") + self._write_replicate_file() self.run.finish() + def _write_replicate_file(self): + if self._replicate_path is None: + return + + import json + + import fsspec + + metrics_file = f"{self._replicate_path}/tracker_metrics.jsonl" + fs, _, _ = fsspec.get_fs_token_paths(metrics_file) + fs.makedirs(self._replicate_path, exist_ok=True) + + with fs.open(metrics_file, "w") as f: + record = { + "config": _convert_value_to_loggable_rec(dict(self.run.config)), + "summary": _convert_value_to_loggable_rec(dict(self.run.summary)), + } + f.write(json.dumps(record, sort_keys=True, default=str) + "\n") + def _convert_value_to_loggable_rec(value: Any): if isinstance(value, (list, tuple)): @@ -113,6 +134,13 @@ def _convert_value_to_loggable_rec(value: Any): return value.item() else: return np.array(value) + elif isinstance(value, np.ndarray): + if value.ndim == 0: + return value.item() + else: + return value.tolist() + elif isinstance(value, np.generic): + return value.item() elif isinstance(value, Histogram): import wandb @@ -160,6 +188,9 @@ class WandbConfig(TrackerConfig): save_xla_dumps: bool = False """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" + replicate_path: Optional[str] = None + """If set, write config and summary to this path (local or GCS) on finish().""" + def init(self, run_id: Optional[str]) -> WandbTracker: import wandb @@ -240,7 +271,7 @@ def init(self, run_id: Optional[str]) -> WandbTracker: wandb.summary["num_hosts"] = jax.process_count() # type: ignore wandb.summary["backend"] = jax.default_backend() # type: ignore - return WandbTracker(r) + return WandbTracker(r, replicate_path=self.replicate_path) def _git_settings(self): other_settings = dict() diff --git a/lib/marin/pyproject.toml b/lib/marin/pyproject.toml index adfd352b9e..deb2d449a3 100644 --- a/lib/marin/pyproject.toml +++ b/lib/marin/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "google-cloud-storage", "google-cloud-storage-transfer", "jax==0.8.0", # vllm-tpu currently requires this exact version + "jaxopt>=0.8.3", "haliax", "levanter[serve]", "lm-eval@git+https://github.com/stanford-crfm/lm-evaluation-harness@d5e3391f22cde186c827674d5c3ec7c5f4fe0cab", @@ -37,6 +38,7 @@ dependencies = [ "numpy", "openai", "pandas", + "plotly", "pyarrow>=22", "ray==2.53.0", "regex", @@ -66,6 +68,8 @@ test = [ # need this for integration tests "pip", "openai-responses", + # for scaling law plotting tests + "plotly", ] lint = [ "ruff==0.14.3", diff --git a/lib/marin/src/marin/processing/tokenize/__init__.py b/lib/marin/src/marin/processing/tokenize/__init__.py index 3413b817f1..0706bc9d49 100644 --- a/lib/marin/src/marin/processing/tokenize/__init__.py +++ b/lib/marin/src/marin/processing/tokenize/__init__.py @@ -15,6 +15,7 @@ from .data_configs import ( TokenizerStep, add_validation_sets_to_mixture, + get_vocab_size_for_tokenizer, lm_data_config, lm_mixture_data_config, mixture_for_evaluation, diff --git a/lib/marin/src/marin/processing/tokenize/data_configs.py b/lib/marin/src/marin/processing/tokenize/data_configs.py index 4076f9c01b..23c0f78523 100644 --- a/lib/marin/src/marin/processing/tokenize/data_configs.py +++ b/lib/marin/src/marin/processing/tokenize/data_configs.py @@ -34,6 +34,14 @@ logger = logging.getLogger(__name__) +_KNOWN_VOCAB_SIZES: dict[str, int] = { + "EleutherAI/gpt-neox-20b": 50_257, + "meta-llama/Meta-Llama-3.1-8B": 128_256, + "stanford-crfm/marin-tokenizer": 128_256, + 
"meta-llama/Llama-2-7b": 32_000, + "gpt2": 50_257, +} + def step_to_lm_mixture_component(step: TokenizerStep | TokenizeConfig, include_raw_paths: bool) -> LMDatasetSourceConfig: """ @@ -333,6 +341,24 @@ def _load_tokenizer(tokenizer_name: str) -> transformers.PreTrainedTokenizer: return load_tokenizer_with_backoff(tokenizer_name) +@lru_cache(maxsize=128) +def get_vocab_size_for_tokenizer(tokenizer_name: str) -> int: + """Return the vocabulary size for a tokenizer name. + + Args: + tokenizer_name: HuggingFace tokenizer name or path. + + Returns: + Vocabulary size for the tokenizer. + """ + resolved_name = unwrap_versioned_value(tokenizer_name) + if resolved_name in _KNOWN_VOCAB_SIZES: + return _KNOWN_VOCAB_SIZES[resolved_name] + + tokenizer = _load_tokenizer(resolved_name) + return len(tokenizer) + + def _are_tokenizers_equivalent(tokenizer1: str, tokenizer2: str) -> bool: """Compare two tokenizers by loading them and comparing their vocabularies and token IDs""" tokenizer1 = unwrap_versioned_value(tokenizer1) diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 731b4c72e7..cfeb053bd0 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -11,3 +11,54 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from marin.scaling_laws.isoflop_analysis import ( + DEFAULT_BUDGETS, + DEFAULT_EVAL_METRIC_KEY, + DEFAULT_SEQ_LEN, + CandidateConfig, + FitScalingLawsResult, + IsoFlopRecord, + MinimaRecord, + ModelConfiguration, + QuadraticFitCoeffs, + ScalingFit, + ScalingRecipe, + fit_scaling_laws, + predict_optimal_config, + round_flops_to_bucket, +) +from marin.scaling_laws.tpu_utils import ( + pick_v5p_type, +) +from marin.scaling_laws.scaling_plots import ( + create_isoflop_plot, + create_scaling_plot, + save_plots, + upload_plots_to_wandb, +) + +__all__ = [ + # Constants + "DEFAULT_BUDGETS", + "DEFAULT_EVAL_METRIC_KEY", + "DEFAULT_SEQ_LEN", + # Data classes and Protocols + "CandidateConfig", + "FitScalingLawsResult", + "IsoFlopRecord", + "MinimaRecord", + "ModelConfiguration", + "QuadraticFitCoeffs", + "ScalingFit", + "ScalingRecipe", + # Functions + "create_isoflop_plot", + "create_scaling_plot", + "fit_scaling_laws", + "pick_v5p_type", + "predict_optimal_config", + "round_flops_to_bucket", + "save_plots", + "upload_plots_to_wandb", +] diff --git a/lib/marin/src/marin/scaling_laws/create_ladder_suite.py b/lib/marin/src/marin/scaling_laws/create_ladder_suite.py deleted file mode 100644 index 5f147e7aef..0000000000 --- a/lib/marin/src/marin/scaling_laws/create_ladder_suite.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Creates a suite of runs for scaling laws- based on https://arxiv.org/pdf/2412.04403 and https://github.com/marin-community/marin/issues/646. 
-""" - -import dataclasses -import logging -from collections.abc import Sequence - -from fray.cluster import ResourceConfig -from levanter.data.text import LMMixtureDatasetConfig -from levanter.models.llama import LlamaConfig - -from experiments.defaults import default_train -from experiments.evals.task_configs import CORE_TASKS_PLUS_MMLU -from experiments.llama import llama_1_4b -from experiments.simple_train_config import SimpleTrainConfig -from marin.execution.executor import ExecutorStep, InputName - -DEFAULT_MODEL_CONFIG = LlamaConfig( - max_seq_len=4096, - hidden_dim=2048, - intermediate_dim=7168, - num_heads=16, - num_kv_heads=8, - num_layers=16, -) - -WS_EMA_DEFAULT_TRAIN_CONFIG = SimpleTrainConfig( - resources=ResourceConfig.with_tpu("v4-128", slice_count=1), - train_batch_size=1024, - learning_rate=1e-3, # placeholder, this will be replaced in the scaling law suite - weight_decay=0.1, - # https://arxiv.org/pdf/2412.04403 gets 4 points per run. this gives us 5 - num_train_steps=50000, # 4096 * 1024 * 50000 = ~200B tokens - warmup=1000, # initial warmup - decay=0.0, # no decay - lr_schedule="constant", - ema_beta=0.995, - steps_per_eval=500, - steps_per_task_eval=500, -) - - -def scaling_law_suite( - sweep_name: str, - tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, - widths: Sequence[int] = (512, 768, 1024, 1536, 2048), - base_model_config: LlamaConfig = llama_1_4b, - tags: Sequence[str] = (), - *, - intermediate_scale: float = 4, - training_config: SimpleTrainConfig = WS_EMA_DEFAULT_TRAIN_CONFIG, - base_lr: float = 3e-4 * 4096, - max_lr: float = 5e-3, -) -> Sequence[ExecutorStep]: - """ - Provides width-wise scaling suite using WSD-S (or other) training configurations. - - Assumptions (consistent with llama 3): - * 128 head_dim - * 8 key-value heads unless that doesn't work with head_dim = 128 - * intermediate_dim = _round_to_multiple(intermediate_scale * width, 128) - * all widths are divisible by 128 - * peak lr is scaled to be base_lr / width, but clamped to max_lr - - Args: - sweep_name: prefix for the sweep name. runs will be named {sweep_name}-{width}-{hash} - base_model_config: base model configuration. Sweep will be generated by varying the width. - tokenized: input data for training - widths: range of widths to sweep over - training_config: training configuration - - References: - * default widths are from https://arxiv.org/pdf/2412.04403 table 1 (plus 512) - * intermediate scale is 4; should be 8 based on https://arxiv.org/pdf/2412.04403 table 1, - but we ultimately decided to go with a smaller value based on - https://arxiv.org/pdf/2407.21783 table 3 since 8 seemed large compared to - other works. - * base_lr is based on llama 3 (https://arxiv.org/pdf/2407.21783 table 3) - * max_lr is a reasonable value that is not too high - * default model config (1_4b) gives the number of layers used in https://arxiv.org/pdf/2412.04403 table 1 - * lr scaling is based on µP/µTransfer: https://arxiv.org/pdf/2203.03466 where generally speaking, lr should - be scaled down by the width of the model. 
- """ - - steps = [] - for w in widths: - intermediate_dim = _round_to_multiple(intermediate_scale * w, 128) - head_size = 128 # keeping this 128 means we can use splash attention - num_heads = w // head_size - num_kv_heads = min(num_heads, 8) - assert num_heads * head_size == w, f"Number of heads must divide width: {w} % {head_size} != 0" - - # if num_kv_heads doesn't divide num_heads, we need to adjust num_kv_heads - if num_heads % num_kv_heads != 0: - num_kv_heads = num_heads - - model_config = dataclasses.replace( - base_model_config, - hidden_dim=w, - intermediate_dim=intermediate_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - ) - - lr = min(base_lr / w, max_lr) - training_config = dataclasses.replace(training_config, learning_rate=lr) - - logging.info(f"Creating training step for {sweep_name}-{w} with width {w} and lr {lr}") - - steps.append( - default_train( - name=f"{sweep_name}-{w}", - tokenized=tokenized, - model_config=model_config, - train_config=training_config, - tags=tags, - eval_harness_tasks=CORE_TASKS_PLUS_MMLU, - ) - ) - return steps - - -def _round_to_multiple(x, multiple): - return int(multiple * round(x / multiple)) diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py new file mode 100644 index 0000000000..c557b83c72 --- /dev/null +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -0,0 +1,141 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base infrastructure for eval metrics analysis. + +This module provides utilities for analysis jobs that read tracker_metrics.jsonl +files from completed training runs. +""" + +import json +import logging +import os +from collections.abc import Sequence + +import fsspec +import wandb + +from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT + +logger = logging.getLogger(__name__) + + +def extract_run_name_from_path(path: str) -> str: + """Extract run name (last component) from a checkpoint path. + + E.g., 'gs://bucket/checkpoints/my-run-abc123' -> 'my-run-abc123' + """ + return os.path.basename(path.rstrip("/")) + + +def _backfill_metrics_from_wandb( + checkpoint_path: str, + metrics_file: str, + entity_project: str, +) -> bool: + """ + Backfill tracker_metrics.jsonl from WandB for a training run. + + Writes a single record with config and summary, matching the format + written by WandbTracker.finish() when replicate_path is set. 
+ + Args: + checkpoint_path: Path to the checkpoint directory + metrics_file: Full path to where tracker_metrics.jsonl should be written + entity_project: WandB entity/project (format: 'entity/project') + + Returns: + True if backfill succeeded, False otherwise + """ + try: + run_id = extract_run_name_from_path(checkpoint_path) + logger.info(f"Attempting to backfill metrics for run_id: {run_id}") + + api = wandb.Api() + run = api.run(f"{entity_project}/{run_id}") + + # Build record matching WandbTracker._write_replicate_file format + record = { + "config": dict(run.config), + "summary": {k: v for k, v in run.summary.items() if not k.startswith("_")}, + } + + fs, _, _ = fsspec.get_fs_token_paths(metrics_file) + fs.makedirs(os.path.dirname(metrics_file), exist_ok=True) + + with fs.open(metrics_file, "w") as f: + f.write(json.dumps(record, sort_keys=True, default=str) + "\n") + + logger.info(f"Successfully backfilled metrics to {metrics_file}") + return True + + except Exception as e: + logger.warning(f"Failed to backfill metrics from WandB: {e}") + return False + + +def read_eval_records( + training_runs: Sequence[str], + metrics_filename: str = "tracker_metrics.jsonl", + wandb_entity_project: str = f"{WANDB_ENTITY}/{WANDB_PROJECT}", +) -> list[dict]: + """Read raw eval metrics from training runs. + + This is the shared utility that all analysis subtypes use to load metrics. + It handles reading JSONL files and WandB backfill when files are missing. + + Args: + training_runs: List of training run output paths. + metrics_filename: Name of the metrics file within each checkpoint directory. + wandb_entity_project: WandB entity/project to query for backfill (format: 'entity/project'). + + Returns: + List of raw records, each containing config, summary, run_index, and run_path. + """ + all_records = [] + + for i, run_path in enumerate(training_runs): + metrics_file = os.path.join(run_path, metrics_filename) + + fs, _, _ = fsspec.get_fs_token_paths(metrics_file) + + if not fs.exists(metrics_file): + logger.info(f"{metrics_file} does not exist, attempting to backfill from WandB...") + + success = _backfill_metrics_from_wandb( + checkpoint_path=run_path, + metrics_file=metrics_file, + entity_project=wandb_entity_project, + ) + if not success: + raise RuntimeError( + f"Backfill from WandB failed for run {i} (path={run_path}, metrics_file={metrics_file})" + ) + + with fs.open(metrics_file, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + + record = json.loads(line) + record["run_index"] = i + record["run_path"] = run_path + all_records.append(record) + + if not all_records: + logger.warning("No eval metrics found in any training runs") + + logger.info(f"Loaded {len(all_records)} evaluation records from {len(training_runs)} runs") + return all_records diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py new file mode 100644 index 0000000000..a5d57cc50e --- /dev/null +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -0,0 +1,446 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""IsoFLOP analysis for finding compute-optimal training configurations. + +This module provides the core data types and analysis functions for IsoFLOP +scaling law analysis. It is intentionally schema-agnostic - experiment code +should transform raw metrics into IsoFlopRecord before calling these functions. + +Key types: +- IsoFlopRecord: The contract for a single training run's metrics +- FitScalingLawsResult: Output from fit_scaling_laws() +- CandidateConfig: Complete training configuration (model, optimizer, schedule) + +Key functions: +- fit_scaling_laws(records): Fit scaling laws from typed records +- predict_optimal_config(): Predict optimal training config for a target budget +""" + +import logging +import math +from collections.abc import Iterator +from dataclasses import dataclass +from typing import NamedTuple, Protocol + +import jax.numpy as jnp +from jaxopt import ScipyMinimize + +from levanter.optim.config import OptimizerConfig + +logger = logging.getLogger(__name__) + +# ---------------- Constants ---------------- + +# Paloma is a standard LLM evaluation benchmark. C4-en BPB (bits-per-byte) is a +# common loss metric that measures model perplexity on the C4 English dataset. +# See: https://arxiv.org/abs/2312.10523 +DEFAULT_EVAL_METRIC_KEY = "eval/paloma/c4_en/bpb" +DEFAULT_SEQ_LEN = 4096 + +# ---------------- IsoFLOP Sweep Constants ---------------- +# Budgets in training FLOPs (includes 3x multiplier for forward + backward pass). +# This matches how FLOPs are tracked in WandB via Levanter's log_performance_stats. +DEFAULT_BUDGETS: tuple[float, ...] = (1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20) + + +# ---------------- Typed Tuples ---------------- + + +class ScalingFit(NamedTuple): + """Scaling law fit parameters for D* ~ A * C^alpha (optimal tokens ~ compute^alpha).""" + + alpha: float + """Exponent in scaling law.""" + + A: float + """Coefficient in scaling law.""" + + +class QuadraticFitCoeffs(NamedTuple): + """Quadratic fit coefficients for loss = a * log10(tokens)^2 + b * log10(tokens) + c.""" + + a: float + """Quadratic coefficient.""" + + b: float + """Linear coefficient.""" + + c: float + """Constant term.""" + + token_min: float + """Minimum token count used for fitting.""" + + token_max: float + """Maximum token count used for fitting.""" + + +# ---------------- IsoFlopRecord ---------------- + + +@dataclass +class IsoFlopRecord: + """A single training run record for isoflop analysis. + + This is the contract between experiment code (which knows how to extract + these fields from raw metrics) and the analysis code (which just does math). + """ + + tokens: float + """Total tokens trained on.""" + + metric: float + """Evaluation metric value (e.g., bits-per-byte from Paloma).""" + + flops: float + """Total training FLOPs (bucketed).""" + + params: float + """Parameter count.""" + + label: str + """Experiment label for grouping (e.g., 'nemo', 'dclm').""" + + +# ---------------- Model Configuration Protocol ---------------- + + +class ModelConfiguration(Protocol): + """Protocol for model configs used in scaling law calculations. 
+ + Any model config that implements these methods can be used with the + scaling law functions. This allows the library to be model-agnostic + while still working with LlamaConfig, QwenConfig, etc. + """ + + def flops_per_token(self, vocab_size: int, seq_len: int) -> float: + """Return FLOPs per token for this model configuration.""" + ... + + def total_trainable_params(self, vocab_size: int) -> int: + """Return total trainable parameter count for this model configuration.""" + ... + + +# ---------------- Candidate Config ---------------- + + +@dataclass +class CandidateConfig: + """Complete training configuration for a scaling law candidate. + + Contains everything needed to run a training job: + - model_config: The model architecture + - optimizer_config: Optimizer with learning rate, beta2, etc. + - batch_size: Training batch size + - train_steps: Number of training steps + - tokens: Total tokens to train on (batch_size * train_steps * seq_len) + - flops_budget: The compute budget this config was generated for + """ + + model_config: ModelConfiguration + """Model configuration for this candidate.""" + + optimizer_config: OptimizerConfig + """Optimizer configuration with learning rate, weight decay, etc.""" + + batch_size: int + """Training batch size.""" + + train_steps: int + """Number of training steps.""" + + tokens: float + """Total tokens to train on.""" + + flops_budget: float + """Compute budget this config was generated for.""" + + +class ScalingRecipe(Protocol): + """Protocol defining the interface for scaling law recipes. + + Concrete implementations (e.g., Marin2025Recipe) should implement these + model-specific methods. The recipe owns the vocab_size, which is derived + from the tokenizer choice. + """ + + name: str + """Name identifying this recipe (e.g., 'marin-2025').""" + + vocab_size: int + """Vocabulary size for the tokenizer used with this recipe.""" + + def estimate_memory_bytes(self, candidate: CandidateConfig) -> int: + """Estimate memory usage in bytes for training a candidate configuration.""" + ... + + def candidates_for_budget( + self, + budget: float, + seq_len: int = DEFAULT_SEQ_LEN, + ) -> Iterator[CandidateConfig]: + """Yield valid candidate training configs for the given FLOP budget. + + This is the main entry point for generating training configurations. + Implementations should iterate over model architectures and yield + complete CandidateConfig objects with model, optimizer, batch size, + and training steps all configured. + """ + ... 
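To make the two protocols concrete, a toy recipe satisfying `ModelConfiguration` and `ScalingRecipe` might look like the sketch below. It is only a sketch: the 6N FLOPs-per-token and 16-bytes-per-parameter estimates are crude stand-ins, the architecture grid is arbitrary, and `AdamConfig` (assumed importable from `levanter.optim`) merely fills the `OptimizerConfig` slot. The real recipe (e.g., `MARIN_2025_RECIPE`) enumerates actual architectures and tunes hyperparameters per budget.

```python
from collections.abc import Iterator
from dataclasses import dataclass

from levanter.optim import AdamConfig

from marin.scaling_laws.isoflop_analysis import CandidateConfig, DEFAULT_SEQ_LEN


@dataclass(frozen=True)
class ToyModelConfig:
    """Minimal stand-in implementing the ModelConfiguration protocol."""

    hidden_dim: int
    num_layers: int

    def total_trainable_params(self, vocab_size: int) -> int:
        # 12 * L * d^2 for the transformer body, plus embedding and unembedding.
        return 12 * self.num_layers * self.hidden_dim**2 + 2 * vocab_size * self.hidden_dim

    def flops_per_token(self, vocab_size: int, seq_len: int) -> float:
        # Crude 6N training-FLOPs-per-token approximation (ignores the attention seq_len term).
        return 6.0 * self.total_trainable_params(vocab_size)


@dataclass(frozen=True)
class ToyRecipe:
    """Minimal stand-in implementing the ScalingRecipe protocol."""

    name: str = "toy"
    vocab_size: int = 128_256

    def estimate_memory_bytes(self, candidate: CandidateConfig) -> int:
        # ~16 bytes per parameter: fp32 params, grads, and two Adam moments.
        return 16 * candidate.model_config.total_trainable_params(self.vocab_size)

    def candidates_for_budget(self, budget: float, seq_len: int = DEFAULT_SEQ_LEN) -> Iterator[CandidateConfig]:
        for hidden_dim in (512, 1024, 2048):
            model = ToyModelConfig(hidden_dim=hidden_dim, num_layers=max(4, hidden_dim // 128))
            # Spend the whole budget on this architecture: tokens = C / (FLOPs per token).
            tokens = budget / model.flops_per_token(self.vocab_size, seq_len)
            batch_size = 256
            train_steps = max(1, round(tokens / (batch_size * seq_len)))
            yield CandidateConfig(
                model_config=model,
                optimizer_config=AdamConfig(learning_rate=1e-3),
                batch_size=batch_size,
                train_steps=train_steps,
                tokens=float(train_steps * batch_size * seq_len),
                flops_budget=budget,
            )
```

Anything shaped like this can be handed to `predict_optimal_config` or used to enumerate candidates for an isoflop sweep.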
+
+
+# ---------------- Typed Records ----------------
+
+
+@dataclass
+class MinimaRecord:
+    """Model-agnostic record of the optimal configuration found at a specific (label, flops) point."""
+
+    label: str
+    flops: float
+    optimal_tokens: float
+    loss_at_optimal: float
+    optimal_params: float
+    scaling_alpha: float | None = None
+    scaling_A: float | None = None
+
+
+@dataclass
+class FitScalingLawsResult:
+    """Result from fit_scaling_laws containing minima, scaling fits, and fit curves."""
+
+    minima_records: list[MinimaRecord]
+    """List of optimal configurations found at each (label, flops) point."""
+
+    scaling_fits: dict[str, ScalingFit]
+    """Per-label scaling fits: {label: ScalingFit} for D* ~ A * C^alpha."""
+
+    fit_curves: dict[tuple[str, float], QuadraticFitCoeffs]
+    """Quadratic fit coefficients {(label, flops): QuadraticFitCoeffs} for plotting."""
+
+
+# ---------------- Candidate Config Generation ----------------
+
+
+def round_to_power_of_two(x: float) -> int:
+    """Round ``x`` up to the nearest power of two."""
+    if x <= 1:
+        return 1
+    return 2 ** math.ceil(math.log2(x))
+
+
+def round_flops_to_bucket(flops: float, base: float = 1.1) -> float:
+    """Round FLOP count to the nearest power of ``base``.
+
+    Args:
+        flops: FLOP count to round.
+        base: Base for the power buckets (default 1.1 for ~10% buckets).
+    """
+    if flops <= 0:
+        return flops
+
+    k = math.log(flops) / math.log(base)
+    return base ** round(k)
+
+
+def robust_quad_logx(x: jnp.ndarray, y: jnp.ndarray, delta: float = 1.0) -> tuple[float, float, float]:
+    """Fit a robust quadratic in log10(x) space using Huber loss.
+
+    Log10 space is used because sweep token counts span several orders of
+    magnitude (e.g., 1e18, 1e19, 3e19), so fitting against log10(x) keeps the
+    points roughly evenly spread.
+
+    The Huber loss provides robustness to outliers compared to standard least squares.
+
+    Args:
+        x: Input array (e.g., token counts). Must be positive.
+        y: Output array (e.g., loss values).
+        delta: Huber loss threshold. Residuals larger than delta use linear loss.
+
+    Returns:
+        Tuple (a, b, c) of coefficients for: loss = a * log10(x)^2 + b * log10(x) + c
+    """
+    L = jnp.log10(x)
+
+    def huber(residual):
+        abs_r = jnp.abs(residual)
+        quad = 0.5 * residual**2
+        linear = delta * (abs_r - 0.5 * delta)
+        return jnp.where(abs_r <= delta, quad, linear)
+
+    def objective(params):
+        a, b, c = params
+        pred = a * L**2 + b * L + c
+        residuals = y - pred
+        return jnp.sum(huber(residuals))
+
+    opt = ScipyMinimize(fun=objective, method="BFGS", value_and_grad=False)
+    init = jnp.array(jnp.polyfit(L, y, 2)) if len(L) >= 3 else jnp.array([0.0, *jnp.polyfit(L, y, 1)])
+    result = opt.run(init_params=init).params
+    return float(result[0]), float(result[1]), float(result[2])
+
+
+# ---------------- Core Analysis ----------------
+
+
+def fit_scaling_laws(
+    records: list[IsoFlopRecord],
+) -> FitScalingLawsResult:
+    """Fit scaling laws and extract optimal configurations.
+
+    Args:
+        records: List of IsoFlopRecord with tokens, metric, flops, params, and label.
+
+    Returns:
+        FitScalingLawsResult containing minima_records, scaling_fits, and fit_curves.
+ """ + if not records: + return FitScalingLawsResult(minima_records=[], scaling_fits={}, fit_curves={}) + + # Get unique labels preserving order of first appearance + datasets = list(dict.fromkeys(r.label for r in records)) + + # Get unique flop buckets + buckets = sorted(set(r.flops for r in records)) + + minima_records: list[MinimaRecord] = [] + fit_curves: dict[tuple[str, float], QuadraticFitCoeffs] = {} + + # Fit quadratic for each (label, budget) and find minima + for lab in datasets: + for C in buckets: + sub = sorted( + [r for r in records if r.flops == C and r.label == lab], + key=lambda r: r.tokens, + ) + if not sub: + continue + + # Robust quadratic fit in log10(tokens) + # Use float64 to avoid int32 overflow for token counts > 2^31 + tokens_array = jnp.array([r.tokens for r in sub], dtype=jnp.float64) + a, b, c = robust_quad_logx( + tokens_array, + jnp.array([r.metric for r in sub], dtype=jnp.float64), + ) + # Store coefficients along with token range used for fitting + fit_curves[(lab, C)] = QuadraticFitCoeffs(a, b, c, float(tokens_array.min()), float(tokens_array.max())) + + if a == 0: + continue + + log_D_opt = -b / (2 * a) + D_star = float(10**log_D_opt) + metric_opt = float(a * log_D_opt**2 + b * log_D_opt + c) + + # Find record with tokens closest to optimal + nearest_record = min(sub, key=lambda r: abs(r.tokens - D_star)) + + minima_records.append( + MinimaRecord( + label=lab, + flops=float(C), + optimal_tokens=D_star, + loss_at_optimal=metric_opt, + optimal_params=nearest_record.params, + ) + ) + + # Fit scaling law D* ~ A * C^alpha per dataset (optimal tokens ~ compute^alpha) + scaling_fits: dict[str, ScalingFit] = {} + by_lab: dict[str, list[MinimaRecord]] = {} + for rec in minima_records: + by_lab.setdefault(rec.label, []).append(rec) + + for lab in datasets: + recs = by_lab.get(lab, []) + if len(recs) < 2: + continue + + recs = sorted(recs, key=lambda r: r.flops) + Cs = jnp.array([r.flops for r in recs]) + Ds = jnp.array([r.optimal_tokens for r in recs]) + + alpha, logA = jnp.polyfit(jnp.log10(Cs), jnp.log10(Ds), 1) + A = float(10**logA) + alpha = float(alpha) + scaling_fits[lab] = ScalingFit(alpha, A) + + # Augment minima records with scaling fit params + for rec in recs: + rec.scaling_alpha = alpha + rec.scaling_A = A + + return FitScalingLawsResult( + minima_records=minima_records, + scaling_fits=scaling_fits, + fit_curves=fit_curves, + ) + + +# ---------------- Predict Optimal Config ---------------- + + +def predict_optimal_config( + scaling_fits: dict[str, ScalingFit], + target_flops: float, + label: str, + recipe: ScalingRecipe, + seq_len: int = DEFAULT_SEQ_LEN, +) -> CandidateConfig | None: + """Predict optimal training config for a target compute budget using fitted scaling laws. + + This implements IsoFLOP Approach 2 from the Chinchilla paper: + 1. D_opt (optimal tokens) is found empirically at each compute budget by fitting + parabolas to actual loss values and finding the minimum. + 2. D_opt ~ A * C^alpha is fitted from those empirical minima. + 3. Given D_opt and C, N_opt (optimal params) is derived as C/(6D), so no + separate alpha fit for params is needed. + + Args: + scaling_fits: Dict of {label: ScalingFit} from scaling ladder result. + target_flops: Target compute budget in FLOPs. + label: Dataset/experiment label to use for scaling fit. + recipe: ScalingRecipe with architecture/hyperparameter settings (includes vocab_size). + seq_len: Sequence length for training. 
+ + Returns: + CandidateConfig for the predicted optimal, or None if label not in fits + or no valid candidates found. + """ + if label not in scaling_fits: + logger.warning(f"Label '{label}' not found in scaling fits") + return None + + alpha, A = scaling_fits[label] + optimal_tokens = A * (target_flops**alpha) + + logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") + + candidates = list(recipe.candidates_for_budget(target_flops, seq_len)) + + if not candidates: + logger.warning(f"No valid candidates found for budget {target_flops:.2e}") + return None + + # Find candidate with tokens >= optimal_tokens, closest to optimal + best = min(candidates, key=lambda c: c.tokens - optimal_tokens if c.tokens >= optimal_tokens else float("inf")) + if best.tokens < optimal_tokens: + best = max(candidates, key=lambda c: c.tokens) + + params = best.model_config.total_trainable_params(recipe.vocab_size) + logger.info(f"Selected config: N={params:.2e}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})") + + return best diff --git a/lib/marin/src/marin/scaling_laws/scaling_laws.py b/lib/marin/src/marin/scaling_laws/scaling_laws.py deleted file mode 100644 index 7d602ac580..0000000000 --- a/lib/marin/src/marin/scaling_laws/scaling_laws.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file contains functions for setting scaling law configurations, wrapper functions to call the relevant -regressions/predictions, and creating a WandB report with the results. The code here implements the function -(see `run_scaling_law_analysis`, that will be called by an ExecutorStep in the scaling laws analysis pipeline. - -Our objective is to predict the accuracy of a larger target model on a specific benchmark. -This prediction is done through a two-step modeling process using (N, D) data from various smaller models: -- we first fit a power-law model to predict the task loss from the number of parameters and tokens. -- then, we fit a sigmoidal model to predict task accuracy from the task loss. - -Reference: - Establishing Task Scaling Laws via Compute-Efficient Model Ladders - Bhagia et. al 2024 - https://arxiv.org/pdf/2412.04403. 
-""" - -from collections.abc import Sequence -from dataclasses import dataclass, field - -import numpy as np -import wandb - -from marin.execution.executor import ExecutorStep -from marin.scaling_laws.utils import ( - ProjectionPoint, - get_default_projection_points, - plot_actual_vs_predicted, - plot_scaling_projections, -) -from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT - - -@dataclass(frozen=True) -class ScalingLawConfig: - name: str - """name of the scaling law analysis or config (used for the report name)""" - - ladder_model_steps: Sequence[ExecutorStep | str] - """list of (smaller model) steps or wandb run ids to be used as input for scaling laws""" - - pred_model_step: ExecutorStep | str - """executor step or wandb run id for the larger model to make predictions for""" - - projection_points: list[ProjectionPoint] | None = None - """Points to project to, consisting of number of parameters and tokens""" - - task_losses: Sequence[str] = field(default_factory=lambda: ["eval/paloma/c4_en/bpb"]) - """task losses to predict for scaling laws (eg. c4en bpb)""" - - task_accuracies: Sequence[str] | None = None - """task accuracy to predict for the larger model (eg. hellaswag accuracy)""" - - use_log_for_ND: bool = True - """whether to use log space for N,D in scaling laws""" - - normalize_ND: bool = True - """whether to normalize N,D in scaling laws""" - - count_embedding_params: bool = False - """whether to count embedding parameters in scaling laws""" - - entity: str = WANDB_ENTITY - project: str = "marin" - - def __post_init__(self): - # Set default projection points if none provided - if self.projection_points is None: - object.__setattr__( - self, - "projection_points", - get_default_projection_points(count_embedding_params=self.count_embedding_params), - ) - - -def get_wandb_run_id_from_step(step: ExecutorStep) -> str: - """ - Get the wandb run id from a given ExecutorStep. - """ - return step.config.trainer.tracker.id - - -def run_scaling_law_analysis(config: ScalingLawConfig) -> None: - """ - Analyze scaling laws for a given task loss and multiple accuracy metrics. 
- """ - from marin.scaling_laws.utils import fit_scaling_laws - - input_run_ids = [ - get_wandb_run_id_from_step(step) if isinstance(step, ExecutorStep) else step - for step in config.ladder_model_steps - ] - - pred_run_id = None - if config.pred_model_step: - pred_run_id = ( - get_wandb_run_id_from_step(config.pred_model_step) - if isinstance(config.pred_model_step, ExecutorStep) - else config.pred_model_step - ) - - projections, predictions = fit_scaling_laws( - runs=input_run_ids, - loss_metrics=config.task_losses, - accuracy_metrics=config.task_accuracies, - entity=config.entity, - project=config.project, - pred_run=pred_run_id, - projection_points=config.projection_points, - count_embedding_params=config.count_embedding_params, - use_log_for_ND=config.use_log_for_ND, - normalize_ND=config.normalize_ND, - ) - - log_and_create_report( - projections=projections, - points=config.projection_points, - predictions=predictions, - input_run_ids=input_run_ids, - pred_run_id=pred_run_id, - scaling_law_config=config, - ) - - -def log_and_create_report( - projections: dict[str, np.ndarray], - points: list[ProjectionPoint] | None, - predictions: tuple[dict, dict, np.ndarray, np.ndarray] | None, - input_run_ids: list, - pred_run_id: str | None, - scaling_law_config: ScalingLawConfig, - wandb_project: str = "marin-scaling-laws", - wandb_entity: str = WANDB_ENTITY, - wandb_source_project: str = WANDB_PROJECT, -): - """ - Logs scaling law analysis creates a concise WandB report with plots and info about runs. - """ - # Initialize WandB run - run = wandb.init( - project=wandb_project, - entity=wandb_entity, - name=f"""Scaling Law Report: {pred_run_id if pred_run_id else 'projection'}-{scaling_law_config.name}""", - tags=["scaling_laws"], - config={ - "input_runs": input_run_ids, - "prediction_run": pred_run_id, - }, - reinit=True, - ) - - plots = {} - - # Log projections - if points: - for loss_name, projection in projections.items(): - figure = plot_scaling_projections(projection, points) - plots[f"Projection - {loss_name}"] = wandb.Image(figure) - - # Log predictions if available - if predictions: - loss_results, accuracy_results, loss_tokens, acc_tokens = predictions - - if loss_results: - for loss_name, (actual_loss, predicted_loss) in loss_results.items(): - figure = plot_actual_vs_predicted( - actual_loss.tolist(), - predicted_loss.tolist(), - title=f"Actual vs Predicted {loss_name}", - task_metric=loss_name, - tokens=loss_tokens, - ) - plots[f"Task Loss - {loss_name}"] = wandb.Image(figure) - - if accuracy_results: - for metric, (actual_acc, predicted_acc) in accuracy_results.items(): - figure = plot_actual_vs_predicted( - actual_acc.tolist(), - predicted_acc.tolist(), - title=f"Actual vs Predicted {metric}", - task_metric=metric, - tokens=acc_tokens, - ) - plots[f"Task Accuracy - {metric}"] = wandb.Image(figure) - - # Log all plots - wandb.log(plots) - - # Info about runs and links - input_run_links = [ - f"https://wandb.ai/{wandb_entity}/{wandb_source_project}/runs/{run_id}" for run_id in input_run_ids - ] - prediction_run_link = ( - f"https://wandb.ai/{wandb_entity}/{wandb_source_project}/runs/{pred_run_id}" if pred_run_id else None - ) - run.summary.update( - { - "Input Runs": input_run_links, - "Prediction Run": prediction_run_link, - "Task Losses": scaling_law_config.task_losses, - "Task Accuracies": scaling_law_config.task_accuracies, - } - ) - - wandb.finish() diff --git a/lib/marin/src/marin/scaling_laws/scaling_plots.py b/lib/marin/src/marin/scaling_laws/scaling_plots.py new file mode 
100644 index 0000000000..6680d7307e --- /dev/null +++ b/lib/marin/src/marin/scaling_laws/scaling_plots.py @@ -0,0 +1,329 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Visualization functions for scaling ladder analysis. + +This module provides plotting utilities for isoflop analysis results. +All plotly-related code is contained here to keep the core scaling_ladder +module free of visualization dependencies. +""" + +import logging +import os + +import fsspec +import jax.numpy as jnp +import pandas as pd +import plotly.graph_objects as go +import plotly.io as pio + +from marin.scaling_laws.isoflop_analysis import QuadraticFitCoeffs, ScalingFit +from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT + +import wandb + +logger = logging.getLogger(__name__) + +# ---------------- Theme ---------------- +pio.templates.default = "plotly_white" + +# ---------------- Visual Constants ---------------- +PALETTE = [ + "#1877F2", + "#F0701A", + "#5A24C7", + "#E42C97", + "#00487C", + "#0EAC96", + "#AB76FF", + "#B50550", + "#0099E6", + "#22085F", + "#783301", +] + +MARKERS = [ + "circle", + "square", + "cross", + "x", + "triangle-up", + "triangle-down", + "triangle-left", + "triangle-right", + "pentagon", + "hexagon", + "hexagon2", + "star", + "star-triangle-up", + "star-triangle-down", + "star-square", + "star-diamond", + "hourglass", + "bowtie", +] + +DASHES = ["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"] + +_MIN_MARKER = dict(symbol="diamond", size=10, color="#000000") +_SCALE_MARKER = dict(symbol="circle", size=9, color=PALETTE[0]) +_SCALE_LINE = dict(dash="dot", width=2, color=PALETTE[0]) + + +def create_isoflop_plot( + df: pd.DataFrame, + minima_records: list, + fit_curves: dict[tuple[str, float], QuadraticFitCoeffs], +) -> go.Figure: + """Create the IsoFLOP plot showing loss vs tokens for each compute budget. 
+
+    Args:
+        df: DataFrame with columns: tokens, loss, flops, params, name, label
+        minima_records: List of MinimaRecord with optimal config info per (label, flops)
+        fit_curves: Dict of {(label, flops): QuadraticFitCoeffs} quadratic fit coefficients
+
+    Returns:
+        Plotly Figure with the isoflop visualization
+    """
+    if df.empty:
+        return go.Figure()
+
+    datasets = list(dict.fromkeys(df["label"].tolist()))
+
+    buckets = sorted(df.flops.unique())
+    bucket_color = {C: PALETTE[i % len(PALETTE)] for i, C in enumerate(buckets)}
+    ds_marker = {lab: MARKERS[i % len(MARKERS)] for i, lab in enumerate(datasets)}
+
+    fig = go.Figure()
+
+    # Build lookup for minima
+    minima_lookup = {(rec.label, rec.flops): rec for rec in minima_records}
+
+    for lab in datasets:
+        for C in buckets:
+            sub = df[(df.flops == C) & (df.label == lab)].sort_values("tokens")
+            if sub.empty:
+                continue
+
+            # Scatter points
+            fig.add_trace(
+                go.Scatter(
+                    x=sub.tokens,
+                    y=sub.loss,
+                    mode="markers",
+                    marker=dict(symbol=ds_marker[lab], color=bucket_color[C], size=8),
+                    name=f"{lab}, {C:.2e} FLOPs",
+                    legendgroup=f"{lab}, {C:.2e}",
+                    hovertemplate=(
+                        "C=%{text:.2e} FLOPs<br>tokens=%{x:.3e}<br>"
+                        "loss=%{y:.4f}<br>params=%{customdata:.3e}"
+                    ),
+                    text=[C] * len(sub),
+                    customdata=sub.params.values,
+                )
+            )
+
+            # Draw fit curve if available
+            key = (lab, C)
+            if key in fit_curves:
+                a, b, c, token_min, token_max = fit_curves[key]
+                if a != 0:
+                    Ls = jnp.linspace(jnp.log10(token_min), jnp.log10(token_max), 200)
+                    fig.add_trace(
+                        go.Scatter(
+                            x=10**Ls,
+                            y=a * Ls**2 + b * Ls + c,
+                            mode="lines",
+                            line=dict(color=bucket_color[C], dash="dash", width=2),
+                            showlegend=False,
+                            legendgroup=f"{lab}, {C:.2e}",
+                        )
+                    )
+
+            # Draw minimum marker
+            if key in minima_lookup:
+                rec = minima_lookup[key]
+                fig.add_trace(
+                    go.Scatter(
+                        x=[rec.optimal_tokens],
+                        y=[rec.loss_at_optimal],
+                        mode="markers",
+                        marker=_MIN_MARKER,
+                        showlegend=False,
+                        legendgroup=f"{lab}, {C:.2e}",
+                        hovertemplate=(
+                            "Compute-optimal<br>"
+                            "C=%{text:.2e} FLOPs<br>tokens=%{x:.3e}<br>"
+                            "loss=%{y:.4f}<br>params=%{customdata:.3e}"
+                        ),
+                        text=[C],
+                        customdata=[rec.optimal_params],
+                    )
+                )
+
+    fig.update_layout(
+        template="plotly_white",
+        xaxis_type="log",
+        xaxis_title="Tokens (log scale)",
+        yaxis_title="Bits Per Byte Validation",
+        title="Marin IsoFLOP Suite",
+        width=1000,
+        height=600,
+    )
+
+    return fig
+
+
+def create_scaling_plot(
+    minima_records: list,
+    scaling_fits: dict[str, ScalingFit],
+) -> go.Figure:
+    """Create the scaling law fit plot showing optimal tokens D* vs compute budget.
+
+    Args:
+        minima_records: List of MinimaRecord with optimal config info per (label, flops)
+        scaling_fits: Dict of {label: ScalingFit} for D* ~ A * C^alpha
+
+    Returns:
+        Plotly Figure with the scaling fit visualization
+    """
+    if not minima_records:
+        return go.Figure()
+
+    # Group by label
+    by_lab = {}
+    for rec in minima_records:
+        by_lab.setdefault(rec.label, []).append(rec)
+
+    datasets = list(by_lab.keys())
+
+    fig = go.Figure()
+
+    for i, lab in enumerate(datasets):
+        recs = by_lab.get(lab, [])
+        if not recs:
+            continue
+
+        recs = sorted(recs, key=lambda r: r.flops)
+        Cs = jnp.array([r.flops for r in recs])
+        Ds = jnp.array([r.optimal_tokens for r in recs])
+
+        color = PALETTE[i % len(PALETTE)]
+        dash = DASHES[i % len(DASHES)]
+
+        # Plot minima points
+        fig.add_trace(
+            go.Scatter(
+                x=list(map(float, Cs)),
+                y=list(map(float, Ds)),
+                mode="markers",
+                marker=dict(symbol=_SCALE_MARKER["symbol"], size=_SCALE_MARKER["size"], color=color),
+                name=f"{lab} minima",
+                legendgroup=lab,
+            )
+        )
+
+        # Plot fit line if available
+        if lab in scaling_fits:
+            alpha, A = scaling_fits[lab]
+            Cmin, Cmax = float(Cs.min()), float(Cs.max())
+            C_fit = jnp.logspace(jnp.log10(Cmin) - 0.1, jnp.log10(Cmax) + 0.1, 400)
+            D_fit = A * (C_fit**alpha)
+
+            fig.add_trace(
+                go.Scatter(
+                    x=list(map(float, C_fit)),
+                    y=list(map(float, D_fit)),
+                    mode="lines",
+                    line=dict(color=color, dash=dash, width=_SCALE_LINE["width"]),
+                    name=f"{lab} fit (alpha={alpha:.3f})",
+                    legendgroup=lab,
+                )
+            )
+
+    fig.update_layout(
+        template="plotly_white",
+        xaxis_type="log",
+        yaxis_type="log",
+        xaxis_title="Compute budget C (FLOPs, log)",
+        yaxis_title="Optimal tokens D* (log)",
+        title="Scaling fits per dataset",
+    )
+
+    return fig
+
+
+def save_plots(
+    fig_isoflop: go.Figure,
+    fig_scaling: go.Figure,
+    output_path: str,
+) -> None:
+    """Save isoflop and scaling plots to HTML files.
+
+    Args:
+        fig_isoflop: IsoFLOP plot figure
+        fig_scaling: Scaling fit plot figure
+        output_path: Directory path to save plots
+    """
+    fs, _, _ = fsspec.get_fs_token_paths(output_path)
+    fs.makedirs(output_path, exist_ok=True)
+
+    iso_path = os.path.join(output_path, "isoflop_plot.html")
+    scaling_path = os.path.join(output_path, "scaling_plot.html")
+
+    with fs.open(iso_path, "w") as f:
+        f.write(fig_isoflop.to_html())
+    logger.info(f"Wrote isoflop plot to {iso_path}")
+
+    with fs.open(scaling_path, "w") as f:
+        f.write(fig_scaling.to_html())
+    logger.info(f"Wrote scaling plot to {scaling_path}")
+
+
+def upload_plots_to_wandb(
+    fig_isoflop: go.Figure,
+    fig_scaling: go.Figure,
+    entity: str = WANDB_ENTITY,
+    project: str = f"{WANDB_PROJECT}-analysis",
+    run_name: str = "scaling-ladder-analysis",
+) -> None:
+    """Upload plots to Weights & Biases.
+
+    Args:
+        fig_isoflop: IsoFLOP plot figure
+        fig_scaling: Scaling fit plot figure
+        entity: WandB entity
+        project: WandB project
+        run_name: Name for the WandB run
+
+    TODO: Consider extracting a generic wandb-upload utility that takes artifacts
+        and handles upload logic.
This would decouple the plotting logic from WandB + and allow reuse across other analysis tools. + """ + wandb.login() + run = wandb.init( + entity=entity, + project=project, + job_type="scaling-ladder", + name=run_name, + resume="allow", + ) + wandb.log( + { + "isoflop_plot": wandb.Plotly(fig_isoflop), + "scaling_plot": wandb.Plotly(fig_scaling), + } + ) + run.finish() + logger.info("Uploaded plots to WandB") diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py new file mode 100644 index 0000000000..aae9459aec --- /dev/null +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -0,0 +1,56 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TPU hardware utilities for memory estimation and slice selection. + +This module provides utilities for estimating memory requirements and +selecting appropriate TPU slice sizes for training runs. +""" + +import math + +# ---------------- TPU v5p Hardware Constants ---------------- +# These constants are specific to TPU v5p pods. + +HBM_PER_CHIP_GIB = 95 +"""High-bandwidth memory per TPU v5p chip in GiB.""" + +CORES_PER_CHIP = 2 +"""Number of cores per TPU v5p chip.""" + +V5P_CORE_OPTIONS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048] +"""Available TPU v5p core configurations (slice sizes).""" + + +def pick_v5p_type(estimated_memory_bytes: int) -> str: + """Select the smallest TPU v5p slice that fits the estimated memory. + + Args: + estimated_memory_bytes: Estimated memory requirement in bytes. + + Returns: + TPU slice name, e.g., "v5p-8" or "v5p-32". + + Raises: + ValueError: If the model is too large for available v5p slices. + """ + chip_bytes = HBM_PER_CHIP_GIB * 1024**3 + chips = math.ceil(estimated_memory_bytes / chip_bytes) + cores_req = chips * CORES_PER_CHIP + + valid = [c for c in V5P_CORE_OPTIONS if c >= cores_req] + if not valid: + raise ValueError(f"Model too large for available v5p slices (need {cores_req} cores).") + + return f"v5p-{min(valid)}" diff --git a/lib/marin/src/marin/scaling_laws/utils.py b/lib/marin/src/marin/scaling_laws/utils.py deleted file mode 100644 index ed669f7300..0000000000 --- a/lib/marin/src/marin/scaling_laws/utils.py +++ /dev/null @@ -1,667 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Functions for fitting scaling laws and plotting the results. - -The functions in this file implement the techniques in https://arxiv.org/pdf/2412.04403. 
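# A minimal usage sketch for tpu_utils.pick_v5p_type above; the ~16 bytes per
# parameter figure (weights, gradients, and Adam moments) is an illustrative
# rule of thumb, not something the module prescribes:
#
#     from marin.scaling_laws.tpu_utils import pick_v5p_type
#
#     est_bytes = int(1.5e9) * 16   # ~1.5B params -> ~24 GB estimated
#     pick_v5p_type(est_bytes)      # -> "v5p-8": one 95 GiB chip suffices,
#                                   #    and 8 cores is the smallest v5p slice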
- -Our objective is to predict the accuracy of a larger target model on a specific benchmark. -This prediction is done through a two-step modeling process using (N, D) data from various smaller models: -- we first fit a power-law model to predict the task loss from the number of parameters and tokens. -- then, we fit a sigmoidal model to predict the task accuracy from the task loss. - -For further details see the corresponding GitHub issue: https://github.com/marin-community/marin/issues/646. - -To use this code, call fit_scaling_laws() with appropriate arguments. -""" - -from collections.abc import Callable, Sequence -from dataclasses import dataclass -from typing import Any - -import matplotlib.pyplot as plt -import numpy as np -from scipy.optimize import curve_fit, minimize -from scipy.special import huber - -from experiments.llama import compute_num_parameters, llama3_tokenizer_vocab_size - -try: - import pandas as pd -except ImportError: - pd: Any = None - -OPTIMIZATION_TOLERANCE = 1e-10 - -#################################################################################################### -# Power law helpers - - -def power_law_model(params: Sequence[float], N: np.ndarray, D: np.ndarray, use_log_space: bool = True) -> np.ndarray: - """ - Power-law equation: A / N^alpha + B / D^beta + E - - Args: - params: List of parameters [A, B, alpha, beta, E] - N: Number of parameters - D: Number of tokens - use_log_space: Whether to use log space for A and B - """ - if use_log_space: - log_A, log_B, alpha, beta, E = params - A, B = np.exp(log_A), np.exp(log_B) - else: - A, B, alpha, beta, E = params - return A / (N**alpha) + B / (D**beta) + E - - -def power_law_loss( - params: Sequence[float], - N: np.ndarray, - D: np.ndarray, - y: np.ndarray, - use_log_space: bool, - delta: float, - reduction: Callable[[np.ndarray], float] | None = np.sum, -) -> float: - """ - Huber loss for the power-law model. - Args: - params: List of parameters [A, B, alpha, beta, E] - N: Number of parameters - D: Number of tokens - y: Actual loss - use_log_space: if true, residual is set to difference of logs of actual and predicted values - delta: huber loss delta, indicating the quadratic vs. linear loss changepoint. - reduction: Optional argument to change the reduction used on the Huber loss, defaults to sum based on https://arxiv.org/pdf/2404.10102v2 - """ - predictions = power_law_model(params, N, D, use_log_space) - if use_log_space: - residuals = np.log(y) - np.log(predictions) - else: - residuals = y - predictions - return reduction(huber(delta, residuals)) - - -def fit_power_law( - N: np.ndarray, - D: np.ndarray, - y: np.ndarray, - use_log_space: bool = False, - initial_guess: Sequence[float] | None = None, - delta: float = 1e-3, -) -> np.ndarray | tuple[float, float, float, float, float]: - """ - Fit a power law model to the data ((N, D), y). - - Args: - N: Number of parameters - D: Number of tokens - y: Actual loss or metric we want to learn to predict - use_log_space: if true, A and B are in log space *AND* Huber loss is computed in log space. - initial_guess: Initial guess for the parameters - delta: huber loss delta, indicating the quadratic vs. linear loss changepoint. 
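Example (synthetic data; a minimal sketch, not drawn from real runs):

    N = np.array([1e7, 5e7, 1e8, 5e8])      # parameters
    D = np.array([1e9, 5e9, 1e10, 5e10])    # tokens
    y = 2.0 / N**0.3 + 1.5 / D**0.2 + 0.5   # known ground-truth curve
    A, B, alpha, beta, E = fit_power_law(N, D, y, use_log_space=True)
    # Extrapolate with the fitted parameters, e.g. to an 8B model at 160B tokens:
    predict_power_law((A, B, alpha, beta, E), np.array([8e9]), np.array([1.6e11]))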
- """ - # Compute the minimum y value to use as the initial guess for E - min_y = np.min(y) - - if use_log_space: - if initial_guess is None: - # Initialize E to max(min_y, 1e-10) to ensure it's positive - initial_guess = [0.0, 0.0, 1.0, 1.0, max(min_y, 1e-10)] # [log_A, log_B, alpha, beta, E] - bounds = [ - (None, None), # log_A unbounded - (None, None), # log_B unbounded - (0, None), # alpha >= 0 - (0, None), # beta >= 0 - (0, None), # E >= 0 - ] - else: - if initial_guess is None: - # Initialize E to max(min_y, 1e-10) to ensure it's positive - initial_guess = [1.0, 1.0, 1.0, 1.0, max(min_y, 1e-10)] # [A, B, alpha, beta, E] - bounds = [ - (0, None), # A >= 0 - (0, None), # B >= 0 - (0, None), # alpha >= 0 - (0, None), # beta >= 0 - (1e-10, None), # E >= 1e-10 to ensure E is positive - ] - - def objective(params): - return power_law_loss(params, N, D, y, use_log_space, delta) - - result = minimize( - objective, - initial_guess, - method="L-BFGS-B", - bounds=bounds, - options={"ftol": OPTIMIZATION_TOLERANCE, "gtol": OPTIMIZATION_TOLERANCE, "maxiter": 2500}, - ) - - if not result.success: - raise RuntimeError(f"Optimization failed: {result.message}") - - # return the fitted parameters, converting log_A and log_B back to A and B if needed - if use_log_space: - log_A, log_B, alpha, beta, E = result.x - A, B = np.exp(log_A), np.exp(log_B) - return A, B, alpha, beta, E - else: - return result.x - - -def predict_power_law(params: Sequence[float], N: np.ndarray, D: np.ndarray) -> np.ndarray: - A, B, alpha, beta, E = params - return A / (N**alpha) + B / (D**beta) + E - - -#################################################################################################### -# Sigmoidal fit helpers - - -def fit_sigmoidal(L: np.ndarray, y: np.ndarray, initial_guess: Sequence[float] | None = None) -> np.ndarray: - """ - Fit a sigmoidal model to the data (L, y). 
- - Equation: a / (1 + exp(-k * (L - L_0))) + b - - Args: - L: Task loss (input array) - y: Ground-truth task accuracy (output array) - initial_guess: Initial guess for [a, b, k, L_0], defaults to data-driven values - - Returns: - popt: Optimized parameters [a, b, k, L_0] - """ - # Set initial guess if not provided - if initial_guess is None: - y_min, y_max = np.min(y), np.max(y) - a_init = y_max - y_min # amplitude - b_init = y_min # offset - k_init = -1.0 # slope (negative for decreasing sigmoid) - L_0_init = np.mean(L) # midpoint - initial_guess = [a_init, b_init, k_init, L_0_init] - - # Set parameter bounds - lower_bounds = [0, 0, -np.inf, -np.inf] # a > 0, b >= 0, k unbounded below, L_0 unbounded - upper_bounds = [np.inf, np.inf, 0, np.inf] # a unbounded above, b unbounded, k <= 0, L_0 unbounded - bounds = (lower_bounds, upper_bounds) - - def objective(L, a, b, k, L_0): - return predict_sigmoidal([a, b, k, L_0], L) - - # Fit the model using scipy's curve_fit() - # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html - popt, _ = curve_fit( - objective, L, y, p0=initial_guess, bounds=bounds, maxfev=15000, method="trf", ftol=OPTIMIZATION_TOLERANCE - ) - - return popt - - -def predict_sigmoidal(params: Sequence[float], L: np.ndarray) -> np.ndarray: - a, b, k, L_0 = params - return a / (1 + np.exp(-k * (L - L_0))) + b - - -#################################################################################################### -# WandB and data processing helpers - - -def pull_metrics_from_wandb( - runs: Sequence[str], - metrics: Sequence[str], - entity: str, - project: str, - summary_fields: Sequence[str] = ("parameter_count",), -) -> pd.DataFrame: - """ - Pulls the metrics from the given runs and returns a DataFrame. - - Args: - runs: List of run IDs - metrics: List of metrics to pull from the runs; these differ depending on the step (unlike summary_fields) - entity: WandB entity - project: WandB project - summary_fields: List of summary fields to pull from the runs - - Returns: - Pandas dataFrame with the metrics - """ - - import wandb - - api = wandb.Api() - - data = [] - for run_id in runs: - run = api.run(f"{entity}/{project}/{run_id}") - run_data = {"run": run.name} - - # Get model configuration - model_dict = run.train_config.get("model", {}) - - run_data["hidden_dim"] = model_dict.get("hidden_dim", 0) - - # get the summary fields - for field in summary_fields: - run_data[field] = run.summary.get(field, None) - - # get the per-step metrics - history = run.history(keys=metrics) - - for i in range(len(history)): - step_data = {m: history[m][i] for m in metrics} - step_data.update(run_data) - step_data["step"] = i - data.append(step_data) - - return pd.DataFrame(data) - - -def filter_zero_d(df: pd.DataFrame, d_key: str = "throughput/total_tokens") -> pd.DataFrame: - """ - Returns a new DataFrame that excludes any rows where the specified - 'd_key' column is zero. - """ - return df[df[d_key] != 0].copy() - - -def extract_scaling_data( - df: pd.DataFrame, - param_count_col: str = "parameter_count", - tokens_col: str = "throughput/total_tokens", - loss_col: str | None = None, - count_embedding_params: bool = False, -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Extracts N, D, and y from the given DataFrame. 
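A minimal round-trip sketch for fit_sigmoidal/predict_sigmoidal above, with
synthetic values chosen to respect the bounds (a > 0, b >= 0, k <= 0):

    L = np.linspace(0.8, 1.6, 20)                     # task losses
    acc = 0.6 / (1 + np.exp(2.5 * (L - 1.2))) + 0.25  # ground truth, k = -2.5
    a, b, k, L_0 = fit_sigmoidal(L, acc)
    predict_sigmoidal([a, b, k, L_0], L)              # ~= acc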
- - Args: - df: DataFrame - param_count_col: Column name for the parameter count - tokens_col: Column name for the tokens - loss_col: Column name for the loss - - Returns: - Tuple of numpy arrays: (N, D, y) where - N = Number of parameters (excluding embedding parameters) - D = Number of tokens - y = Loss - """ - - N = df[param_count_col].values - D = df[tokens_col].values - y = df[loss_col].values if loss_col is not None else None - - # Apply non_embedding_params element-wise - if not count_embedding_params: - N = np.array([non_embedding_params(n, h) for n, h in zip(N, df["hidden_dim"].values, strict=False)]) - - return N, D, y - - -def aggregate_steps( - df: pd.DataFrame, - step_mode: str = "all", - step_range: tuple[int, int] = (1, 5), - group_col: str = "run", -) -> pd.DataFrame: - """ - Aggregates the steps for each run. - - Args: - df: DataFrame - step_mode: how to aggregate the steps - step_range: range of steps to aggregate - group_col: column to group by - - step_mode can be: - - "average": average step_range across each run - - "last": pick the max step within step_range - - "all": keep every step (no grouping) - """ - - if step_mode == "average": - grouped = df.groupby(group_col, as_index=False).mean(numeric_only=True) - return grouped - elif step_mode == "last": - # pick the largest step in the range for each run - def pick_last(g): - last_step_idx = g["step"].idxmax() - return g.loc[last_step_idx] - - grouped = df.groupby(group_col, as_index=False).apply(pick_last) - return grouped.reset_index(drop=True) - elif step_mode == "all": - # no aggregation - return df.copy() - else: - raise ValueError(f"Unknown step_mode: {step_mode}") - - -def non_embedding_params(total_param_count: int, hidden_dim: int, vocab_size: int = llama3_tokenizer_vocab_size): - return total_param_count - 2 * hidden_dim * vocab_size - - -#################################################################################################### -# Projection-specific helpers - - -@dataclass -class ProjectionPoint: - """A point to project to, consisting of number of parameters and tokens""" - - num_params: int - num_tokens: int - - -def get_default_projection_points(count_embedding_params: bool = False) -> list[ProjectionPoint]: - """Default set of model sizes to project to - - Args: - count_embedding_params: Whether to include embedding parameters in parameter count - """ - from experiments.llama import llama_1_4b, llama_8b, llama_13b, llama_24b, llama_70b - - # Base model configs - model_configs = [ - llama_1_4b, - llama_8b, - llama_13b, - llama_24b, - llama_70b, - ] - - # Token multipliers (relative to parameter count - token_multipliers = [0.5, 1, 5, 10, 20, 30, 50, 100] - - points = [] - for config in model_configs: - # Calculate total parameters - total_params = compute_num_parameters(config, llama3_tokenizer_vocab_size) - - # Adjust if we're not counting embedding params - if not count_embedding_params: - total_params = non_embedding_params(total_params, config.hidden_dim) - - # Create points with different token counts - for multiplier in token_multipliers: - num_tokens = int(total_params * multiplier) - points.append(ProjectionPoint(total_params, num_tokens)) - - return points - - -#################################################################################################### -# Plotting helpers - - -def plot_actual_vs_predicted( - y_actual: np.ndarray, - y_predicted: np.ndarray, - title: str = "Actual vs Predicted", - task_metric: str = "eval/paloma/c4_en/bpb", - tokens: np.ndarray | None = None, -) -> 
None: - """ - Plot actual vs predicted values. task_metric is the name of the metric we are predicting. - """ - plt.figure(figsize=(10, 6)) - - x_values = tokens if tokens is not None else np.arange(len(y_actual)) - - # plot actual and predicted values - plt.plot(x_values, y_actual, label="Actual", marker="o", linestyle="-", linewidth=2) - plt.plot(x_values, y_predicted, label="Predicted", marker="x", linestyle="--", linewidth=2) - - # add labels, legend, and title - plt.xlabel("Tokens" if tokens is not None else "Step") - plt.ylabel("Metric: " + task_metric) - plt.title(title) - plt.legend() - plt.grid(True) - - # Format tick labels to show B/T for billions/trillions - if tokens is not None: - - def format_ticks(x, _): - if x >= 1e12: - return f"{x/1e12:.1f}T" - elif x >= 1e9: - return f"{x/1e9:.1f}B" - else: - return f"{x/1e6:.1f}M" - - plt.gca().xaxis.set_major_formatter(plt.FuncFormatter(format_ticks)) - - return plt - - -def plot_scaling_projections(predicted: np.ndarray, points: list[ProjectionPoint] | None = None): - """ - Plot scaling law predictions vs tokens for specified model sizes. - - Args: - predicted: Array of predicted values - points: List of ProjectionPoint objects containing parameter and token counts - - Returns: - matplotlib.pyplot figure object - """ - plt.figure(figsize=(12, 6)) - unique_params = np.unique([p.num_params for p in points]) - - for param in unique_params: - mask = np.array([p.num_params == param for p in points]) - tokens = np.array([p.num_tokens for p in points])[mask] - preds = predicted[mask] - plt.plot(tokens, preds, "o-", linewidth=2, label=f"{param/1e9:.1f}B params") - - # add annotations for each point - for t, pred in zip(tokens, preds, strict=False): - token_str = f"{t/1e9:.1f}B" if t < 1e11 else f"{t/1e12:.1f}T" - plt.annotate(f"{token_str}, {pred:.3f}", (t, pred), ha="center", va="bottom", fontsize=6) - - plt.xscale("log") - plt.xlabel("Number of Tokens") - plt.ylabel("Predicted Loss") - plt.grid(True) - plt.legend() - return plt - - -#################################################################################################### -# Functions for fitting scaling laws - - -def fit_scaling_laws( - runs: list[str], - loss_metrics: Sequence[str], - accuracy_metrics: Sequence[str] | None, - entity: str, - project: str, - pred_run: str | None = None, - projection_points: list[ProjectionPoint] | None = None, - aggregation: str = "all", - tokens_col: str = "throughput/total_tokens", - param_col: str = "parameter_count", - count_embedding_params: bool = False, - use_log_for_ND: bool = False, - normalize_ND: bool = False, -) -> tuple[dict[str, np.ndarray], tuple[dict, dict, np.ndarray, np.ndarray] | None]: - """Fit scaling laws for both projection and prediction - - Args: - runs: list of run IDs to fit scaling laws for - loss_metrics: list of loss metrics to fit scaling laws for - accuracy_metrics: list of accuracy metrics to fit scaling laws for - entity: WandB entity - project: WandB project - pred_run: run ID to predict scaling laws for- if None, no prediction is done - projection_points: list of ProjectionPoint objects to project to - aggregation: how to aggregate steps within each run (all/last/average) - tokens_col: column name for the number of tokens - param_col: column name for the number of parameters - count_embedding_params: whether to count embedding parameters in calculating N - use_log_for_ND: whether to use log space for N and D - normalize_ND: whether to normalize N and D - - Returns: - tuple of: - - dict of loss metrics and their 
predictions - - dict of accuracy metrics and their predictions - - numpy array of tokens for x-axis of plots for losses - - numpy array of tokens for x-axis of plots for accuracies - """ - - # First pull for losses - only essential metrics - metrics = [*list(loss_metrics), tokens_col] - loss_df = pull_metrics_from_wandb( - runs=runs, - metrics=metrics, - entity=entity, - project=project, - summary_fields=(param_col,), - ) - - # Process loss data- remove 0-token runs, apply aggregation to the ladder runs' checkpoints (if specified) - loss_df_filtered = filter_zero_d(loss_df, tokens_col) - loss_df_agg = aggregate_steps(loss_df_filtered, step_mode=aggregation) - - # Get N, D - N, D, _ = extract_scaling_data(loss_df_agg, param_col, tokens_col, count_embedding_params=count_embedding_params) - if use_log_for_ND: - N = np.log(N) - D = np.log(D) - if normalize_ND: - N_scale = np.mean(N) - D_scale = np.mean(D) - N = N / N_scale - D = D / D_scale - - # Handle projections - projections = {} - - if projection_points: - N_proj = np.array([point.num_params for point in projection_points]) - D_proj = np.array([point.num_tokens for point in projection_points]) - - if use_log_for_ND: - N_proj, D_proj = np.log(N_proj), np.log(D_proj) - if normalize_ND: - N_proj, D_proj = N_proj / N_scale, D_proj / D_scale - - for loss_metric in loss_metrics: - y = loss_df_agg[loss_metric].values - params = fit_power_law(N, D, y, use_log_space=True) - projections[loss_metric] = predict_power_law(params, N_proj, D_proj) - - predictions = None - if pred_run: - loss_pred_df = pull_metrics_from_wandb( - runs=[pred_run], - metrics=[*list(loss_metrics), tokens_col], - entity=entity, - project=project, - summary_fields=(param_col,), - ) - - loss_pred_filtered = filter_zero_d(loss_pred_df, tokens_col) - loss_pred_agg = aggregate_steps(loss_pred_filtered, step_mode=aggregation) - - N_pred, D_pred, _ = extract_scaling_data( - loss_pred_agg, param_col, tokens_col, count_embedding_params=count_embedding_params - ) - if use_log_for_ND: - N_pred = np.log(N_pred) - D_pred = np.log(D_pred) - if normalize_ND: - N_pred = N_pred / N_scale - D_pred = D_pred / D_scale - - # Fit losses - loss_results = {} - for loss_metric in loss_metrics: - y = loss_df_agg[loss_metric].values - params = fit_power_law(N, D, y, use_log_space=True) - actual_loss = loss_pred_agg[loss_metric].values - predicted_loss = predict_power_law(params, N_pred, D_pred) - loss_results[loss_metric] = (actual_loss, predicted_loss) - - # Second pull for accuracies - accuracy_results = {} - if accuracy_metrics: - acc_df = pull_metrics_from_wandb( - runs=runs, - metrics=[*list(accuracy_metrics), tokens_col], - entity=entity, - project=project, - summary_fields=(param_col,), - ) - acc_pred_df = pull_metrics_from_wandb( - runs=[pred_run], - metrics=[*list(accuracy_metrics), tokens_col], - entity=entity, - project=project, - summary_fields=(param_col,), - ) - - acc_df_filtered = filter_zero_d(acc_df, tokens_col) - acc_df_agg = aggregate_steps(acc_df_filtered, step_mode=aggregation) - acc_pred_filtered = filter_zero_d(acc_pred_df, tokens_col) - acc_pred_agg = aggregate_steps(acc_pred_filtered, step_mode=aggregation) - - # Fit accuracies - accuracy_results = {} - loss_metric, (actual_loss, predicted_loss) = next(iter(loss_results.items())) # use first loss - - # Merge loss and accuracy data on run and tokens - merged_df = pd.merge( - loss_df_agg[["run", tokens_col, loss_metric]], - acc_df_agg[["run", tokens_col, *accuracy_metrics]], - on=["run", tokens_col], - how="inner", - ) - - # 
Merge prediction data similarly - merged_pred_df = pd.merge( - loss_pred_agg[["run", tokens_col]], - acc_pred_agg[["run", tokens_col, *accuracy_metrics]], - on=["run", tokens_col], - how="inner", - ) - - for acc_metric in accuracy_metrics: - task_losses = merged_df[loss_metric].values - acc = merged_df[acc_metric].values - params = fit_sigmoidal(task_losses, acc) - - acc_pred_actual = merged_pred_df[acc_metric].values - # Get the corresponding predicted losses for these points - pred_indices = loss_pred_agg[tokens_col].isin(merged_pred_df[tokens_col]) - pred_task_losses = predicted_loss[pred_indices] - - acc_preds = predict_sigmoidal(params, pred_task_losses) - accuracy_results[f"{acc_metric}_from_{loss_metric}"] = (acc_pred_actual, acc_preds) - - # Get token counts for plotting - loss_tokens = loss_pred_agg[tokens_col].values - acc_tokens = merged_pred_df[tokens_col].values - - predictions = (loss_results, accuracy_results, loss_tokens, acc_tokens) - - return projections, predictions diff --git a/scripts/migrations/migrate_isoflop_wandb_runs.py b/scripts/migrations/migrate_isoflop_wandb_runs.py new file mode 100644 index 0000000000..35cab22a0a --- /dev/null +++ b/scripts/migrations/migrate_isoflop_wandb_runs.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Migrate WandB isoflop runs to match migrated checkpoint paths. + +After scripts/migrations/migrate_isoflop_checkpoints.sh strips the 6-char hash +suffix from checkpoint paths (e.g., 'isoflop-1e+19-d2048-nemo-abc123' becomes +'isoflop-1e+19-d2048-nemo'), this script copies the corresponding WandB runs +to have matching names without the hash suffix. + +This enables eval_metrics_reader.py to find WandB runs by checkpoint name +without needing complex override mappings. +""" + +import argparse +import logging +import re +import sys + +try: + import wandb +except ImportError: + print("Error: wandb package not installed. Install with: pip install wandb") + sys.exit(1) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +def copy_wandb_run( + api: wandb.Api, + source_run: wandb.apis.public.Run, + new_name: str, + entity: str, + project: str, + dry_run: bool = True, +) -> bool: + """ + Copy a WandB run to a new run with a different name. 
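A minimal invocation sketch for fit_scaling_laws above (run IDs, metric names,
entity, and project are placeholders, not real resources):

    projections, predictions = fit_scaling_laws(
        runs=["ladder-130m", "ladder-300m", "ladder-520m"],
        loss_metrics=("eval/paloma/c4_en/bpb",),
        accuracy_metrics=("eval/hellaswag/acc",),
        entity="my-entity",
        project="my-project",
        pred_run="target-run",
    )
    loss_results, accuracy_results, loss_tokens, acc_tokens = predictions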
+ + Args: + api: WandB API instance + source_run: The source run to copy + new_name: The new name for the copied run + entity: WandB entity + project: WandB project + dry_run: If True, don't actually create the run + + Returns: + True if successful (or would be successful in dry run) + """ + if dry_run: + logger.info(f" [DRY RUN] Would copy {source_run.name} -> {new_name}") + return True + + try: + # Initialize a new run with the clean name + new_run = wandb.init( + entity=entity, + project=project, + name=new_name, + id=new_name, # Use name as ID to make it deterministic + resume="never", + config=dict(source_run.config), + tags=list(source_run.tags), + ) + + # Copy summary metrics + summary = dict(source_run.summary) + for key, value in summary.items(): + new_run.summary[key] = value + + logger.info(f" Created new run: {new_name}") + new_run.finish() + return True + + except Exception as e: + logger.error(f" Failed to copy run {source_run.name}: {e}") + return False + + +def migrate_isoflop_wandb_runs( + entity_project: str, + run_name_filter: str | None = None, + dry_run: bool = True, +) -> None: + """ + Migrate WandB isoflop runs by copying them without hash suffixes. + + Args: + entity_project: WandB entity/project (format: 'entity/project') + run_name_filter: Optional filter to only process specific runs + dry_run: If True, don't actually create runs + """ + if "/" not in entity_project: + raise ValueError(f"Invalid entity_project format: {entity_project}. Expected 'entity/project'") + + entity, project = entity_project.split("/", 1) + api = wandb.Api() + + logger.info(f"Querying WandB for isoflop runs in {entity_project}...") + + # Query for isoflop runs with hash suffixes + filters = { + "displayName": {"$regex": "isoflop"}, + "state": "finished", + } + + runs = api.runs(entity_project, filters=filters) + + migrated_count = 0 + skipped_count = 0 + error_count = 0 + + for run in runs: + display_name = run.displayName + + # Check if this run has a hash suffix + if not re.search(r"-[0-9a-fA-F]{6}$", display_name): + logger.debug(f"Skipping {display_name} (no hash suffix)") + skipped_count += 1 + continue + + # Strip the hash to get the clean name + clean_name = re.sub(r"-[0-9a-fA-F]{6}$", "", display_name) + + # Apply filter if specified + if run_name_filter and run_name_filter not in clean_name: + logger.debug(f"Skipping {display_name} (doesn't match filter)") + skipped_count += 1 + continue + + # Check if a run with the clean name already exists + try: + api.run(f"{entity_project}/{clean_name}") + logger.info(f"Skipping {display_name} -> {clean_name} (already exists)") + skipped_count += 1 + continue + except wandb.errors.CommError: + # Run doesn't exist, we can create it + pass + + logger.info(f"Processing: {display_name} -> {clean_name}") + + if copy_wandb_run(api, run, clean_name, entity, project, dry_run): + migrated_count += 1 + else: + error_count += 1 + + logger.info("\n" + "=" * 60) + logger.info("Migration Summary:") + logger.info(f" Migrated: {migrated_count}") + logger.info(f" Skipped: {skipped_count}") + logger.info(f" Errors: {error_count}") + + if dry_run: + logger.info("\nDry run complete. 
Run with --execute to perform the migration.") + + +def main(): + parser = argparse.ArgumentParser( + description="Migrate WandB isoflop runs to match migrated checkpoint paths", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Dry run (default) + python migrate_isoflop_wandb_runs.py marin-community/marin + + # Execute the migration + python migrate_isoflop_wandb_runs.py marin-community/marin --execute + + # Filter to specific runs + python migrate_isoflop_wandb_runs.py marin-community/marin --filter nemo --execute + """, + ) + + parser.add_argument( + "entity_project", + help="WandB entity/project (format: entity/project)", + ) + + parser.add_argument( + "--execute", + action="store_true", + help="Actually perform the migration (default is dry run)", + ) + + parser.add_argument( + "--filter", + help="Only process runs whose clean name contains this string", + ) + + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + try: + migrate_isoflop_wandb_runs( + entity_project=args.entity_project, + run_name_filter=args.filter, + dry_run=not args.execute, + ) + except Exception as e: + logger.error(f"Migration failed: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py new file mode 100644 index 0000000000..1ac4c29688 --- /dev/null +++ b/tests/test_scaling_laws.py @@ -0,0 +1,224 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the scaling_laws module. + +These tests focus on integration and behavioral validation, particularly +the snapshot test which ensures reproducibility of config generation. 
+""" + +import jax.numpy as jnp + +from marin.scaling_laws.isoflop_analysis import ( + DEFAULT_SEQ_LEN, + CandidateConfig, + fit_scaling_laws, + robust_quad_logx, +) + +# Import the concrete recipe and transform function from experiments +from experiments.isoflop_sweep import Marin2025Recipe, parse_isoflop_run_name, transform_levanter_metrics + +# --- Run name parsing tests --- + + +def test_parse_isoflop_run_name(): + """Test parsing isoflop run names extracts experiment names.""" + # New format: isoflop-{budget}-N{params}-B{batch}-{experiment_name} + assert parse_isoflop_run_name("isoflop-1e+18-N1e+08-B128-nemo-wider-depth-adapt") == "nemo-wider-depth-adapt" + assert parse_isoflop_run_name("isoflop-1e+18-N1e+08-B128-dclm-a1b2c3") == "dclm" # hash stripped + + # Legacy format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} + assert parse_isoflop_run_name("isoflop-1e+18-d512-L8-B128-dclm-a1b2c3") == "dclm" + assert parse_isoflop_run_name("isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt") == "nemo-wider-depth-adapt" + + # Invalid formats return None + assert parse_isoflop_run_name("not-a-valid-name") is None + assert parse_isoflop_run_name("") is None + + +# --- Candidate config tests --- + + +def test_candidate_configs_within_tolerance(): + """Test that generated configs achieve the target FLOP budget within tolerance.""" + recipe = Marin2025Recipe() + budget = 1e19 + flop_tolerance = 0.01 + seq_len = DEFAULT_SEQ_LEN + + # Generate candidates using the new API + for model_config in recipe.build_model_configs(budget, seq_len): + flops_per_token = model_config.flops_per_token(recipe.vocab_size, seq_len) + tokens = budget / (3 * flops_per_token) + candidate = recipe.build_candidate_config(model_config, tokens, budget, seq_len) + + if candidate is None: + continue + + # Compute training FLOPs inline: 3 * flops_per_token * batch * steps * seq_len + achieved = 3 * flops_per_token * candidate.batch_size * candidate.train_steps * seq_len + relative_error = abs(achieved - budget) / budget + assert relative_error <= flop_tolerance + + +# --- Curve fitting tests --- + + +def test_robust_quad_logx_fits_quadratic(): + """Test that robust_quad_logx recovers known coefficients from synthetic data.""" + x = jnp.array([1e9, 1e10, 1e11, 1e12]) + L = jnp.log10(x) + # y = 0.1 * L^2 - 2 * L + 20 + y = 0.1 * L**2 - 2 * L + 20 + + a, b, c = robust_quad_logx(x, y) + + assert abs(a - 0.1) < 0.01 + assert abs(b - (-2)) < 0.1 + assert abs(c - 20) < 0.5 + + +# --- Snapshot test for config generation --- + +# Snapshot of expected output for candidates_for_budget with budget=3e18 training FLOPs. +EXPECTED_ISOFLOP_CONFIGS_3E18 = [ + {"batch_size": 32, "train_steps": 32844, "flops_budget": 3e18}, + {"batch_size": 16, "train_steps": 46274, "flops_budget": 3e18}, + {"batch_size": 16, "train_steps": 33965, "flops_budget": 3e18}, + {"batch_size": 8, "train_steps": 48105, "flops_budget": 3e18}, + {"batch_size": 8, "train_steps": 37335, "flops_budget": 3e18}, +] + + +def test_candidates_for_budget_snapshot(): + """Snapshot test: verify candidates_for_budget produces expected configs. + + This ensures reproducibility of the config generation algorithm. 
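# For the quadratic-in-log-x fit exercised above, the isoflop optimum can be
# read off the fitted coefficients (assuming the analysis takes the parabola's
# vertex as the minimum):
#
#     a, b, c = robust_quad_logx(x, y)   # y ~= a*L**2 + b*L + c, L = log10(x)
#     L_star = -b / (2 * a)              # vertex; 10.0 for the synthetic data
#     optimal_tokens = 10**L_star       # i.e. 1e10 tokens at this budget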
+ """ + recipe = Marin2025Recipe() + result = list(recipe.candidates_for_budget(budget=3e18)) + + assert len(result) == len(EXPECTED_ISOFLOP_CONFIGS_3E18) + + for i, (candidate, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_3E18, strict=True)): + assert isinstance(candidate, CandidateConfig) + assert candidate.batch_size == expected["batch_size"], f"Config {i}: batch_size mismatch" + assert candidate.train_steps == expected["train_steps"], f"Config {i}: train_steps mismatch" + assert candidate.flops_budget == expected["flops_budget"], f"Config {i}: flops_budget mismatch" + + +# --- End-to-end integration test --- + +# Sample tracker_metrics.jsonl data simulating real runs +SAMPLE_METRICS_DATA = [ + # 1e18 budget - 3 runs with U-shaped loss curve + { + "run_path": "gs://marin/checkpoints/isoflop-1e+18-d1024-L11-B8-nemo", + "config": {"model": {"hidden_dim": 1024, "num_layers": 11}, "trainer": {"train_batch_size": 8}}, + "summary": { + "throughput/total_tokens": 1e9, + "throughput/total_gflops": 1e9, + "eval/paloma/c4_en/bpb": 1.25, + "parameter_count": 4e8, + }, + }, + { + "run_path": "gs://marin/checkpoints/isoflop-1e+18-d768-L8-B16-nemo", + "config": {"model": {"hidden_dim": 768, "num_layers": 8}, "trainer": {"train_batch_size": 16}}, + "summary": { + "throughput/total_tokens": 2.5e9, + "throughput/total_gflops": 1e9, + "eval/paloma/c4_en/bpb": 1.12, + "parameter_count": 2.7e8, + }, + }, + { + "run_path": "gs://marin/checkpoints/isoflop-1e+18-d512-L6-B32-nemo", + "config": {"model": {"hidden_dim": 512, "num_layers": 6}, "trainer": {"train_batch_size": 32}}, + "summary": { + "throughput/total_tokens": 5e9, + "throughput/total_gflops": 1e9, + "eval/paloma/c4_en/bpb": 1.18, + "parameter_count": 1.5e8, + }, + }, + # 1e19 budget - 3 runs + { + "run_path": "gs://marin/checkpoints/isoflop-1e+19-d2048-L21-B16-nemo", + "config": {"model": {"hidden_dim": 2048, "num_layers": 21}, "trainer": {"train_batch_size": 16}}, + "summary": { + "throughput/total_tokens": 3e9, + "throughput/total_gflops": 1e10, + "eval/paloma/c4_en/bpb": 1.05, + "parameter_count": 1.8e9, + }, + }, + { + "run_path": "gs://marin/checkpoints/isoflop-1e+19-d1536-L16-B32-nemo", + "config": {"model": {"hidden_dim": 1536, "num_layers": 16}, "trainer": {"train_batch_size": 32}}, + "summary": { + "throughput/total_tokens": 8e9, + "throughput/total_gflops": 1e10, + "eval/paloma/c4_en/bpb": 0.98, + "parameter_count": 1e9, + }, + }, + { + "run_path": "gs://marin/checkpoints/isoflop-1e+19-d1024-L11-B64-nemo", + "config": {"model": {"hidden_dim": 1024, "num_layers": 11}, "trainer": {"train_batch_size": 64}}, + "summary": { + "throughput/total_tokens": 2e10, + "throughput/total_gflops": 1e10, + "eval/paloma/c4_en/bpb": 1.02, + "parameter_count": 4e8, + }, + }, +] + + +def test_end_to_end_analysis_pipeline(): + """Integration test: transform metrics and fit scaling laws end-to-end. + + Uses SAMPLE_METRICS_DATA (simulating real wandb metrics) to verify the full + pipeline: metrics transformation -> curve fitting -> scaling law extraction. 
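# Back-of-envelope check for the fit below: with minima near (C=1e18,
# N*=2.6e9 tokens) and (C=1e19, N*=8.8e9 tokens), the implied exponent is
# alpha = log10(8.8e9 / 2.6e9) / log10(1e19 / 1e18) ~= 0.53, so optimal
# tokens should grow roughly as C**0.53 on this synthetic data.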
+ """ + from marin.scaling_laws import round_flops_to_bucket + + # Transform metrics using the Levanter transform function + records = transform_levanter_metrics(SAMPLE_METRICS_DATA, "eval/paloma/c4_en/bpb") + assert len(records) == 6 + + # Fit scaling laws + fit_result = fit_scaling_laws(records) + + # Should find two minima (one per budget: ~1e18 and ~1e19) + # FLOP values are bucketed by round_flops_to_bucket + assert len(fit_result.minima_records) == 2 + + # Get expected bucketed values + bucket_1e18 = round_flops_to_bucket(1e18) + bucket_1e19 = round_flops_to_bucket(1e19) + assert {rec.flops for rec in fit_result.minima_records} == {bucket_1e18, bucket_1e19} + + # Verify fitted minima are near expected optimal points + minima_by_flops = {rec.flops: rec for rec in fit_result.minima_records} + + # At ~1e18: raw data optimal at 2.5B tokens (loss=1.12) + assert abs(minima_by_flops[bucket_1e18].optimal_tokens - 2.6e9) < 0.2e9 + assert abs(minima_by_flops[bucket_1e18].loss_at_optimal - 1.12) < 0.01 + + # At ~1e19: raw data optimal at 8B tokens (loss=0.98) + assert abs(minima_by_flops[bucket_1e19].optimal_tokens - 8.8e9) < 0.2e9 + assert abs(minima_by_flops[bucket_1e19].loss_at_optimal - 0.98) < 0.01 diff --git a/uv.lock b/uv.lock index f043fde34a..3b542ecdb0 100644 --- a/uv.lock +++ b/uv.lock @@ -3145,6 +3145,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/17/c6d9dc31001a495cb3c52fa69b22a0d8812880cb853f7c0573e2a5edad82/jaxlib-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:659d894d93876e3675c2132d13c3d241f204b21172a58f928b96f654f603f6dc", size = 59323262, upload-time = "2025-10-15T23:10:46.607Z" }, ] +[[package]] +name = "jaxopt" +version = "0.8.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jax" }, + { name = "jaxlib" }, + { name = "numpy" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/da/ff7d7fbd13b8ed5e8458e80308d075fc649062b9f8676d3fc56f2dc99a82/jaxopt-0.8.5.tar.gz", hash = "sha256:2790bd68ef132b216c083a8bc7a2704eceb35a92c0fc0a1e652e79dfb1e9e9ab", size = 121709, upload-time = "2025-04-14T17:59:01.618Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/d8/55e0901103c93d57bab3b932294c216f0cbd49054187ce29f8f13808d530/jaxopt-0.8.5-py3-none-any.whl", hash = "sha256:ff221d1a86908ec759eb1e219ee1d12bf208a70707e961bf7401076fe7cf4d5e", size = 172434, upload-time = "2025-04-14T17:59:00.342Z" }, +] + [[package]] name = "jaxtyping" version = "0.3.5" @@ -4317,6 +4332,7 @@ dependencies = [ { name = "google-cloud-storage-transfer" }, { name = "haliax" }, { name = "jax" }, + { name = "jaxopt" }, { name = "levanter", extra = ["serve"] }, { name = "lm-eval" }, { name = "lxml", extra = ["html-clean"] }, @@ -4326,6 +4342,7 @@ dependencies = [ { name = "numpy" }, { name = "openai" }, { name = "pandas" }, + { name = "plotly" }, { name = "pyarrow" }, { name = "ray" }, { name = "regex" }, @@ -4403,6 +4420,7 @@ dev = [ { name = "mypy" }, { name = "openai-responses" }, { name = "pip" }, + { name = "plotly" }, { name = "pylatexenc" }, { name = "pymdown-extensions" }, { name = "pyrefly" }, @@ -4451,6 +4469,7 @@ metrics = [ test = [ { name = "openai-responses" }, { name = "pip" }, + { name = "plotly" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, @@ -4482,6 +4501,7 @@ requires-dist = [ { name = "jax", marker = "extra == 'cpu'", specifier = "==0.8.0" }, { name = "jax", extras = ["cuda12"], marker = "extra == 'gpu'", specifier = "==0.8.0" }, { name = "jax", extras = 
["tpu"], marker = "extra == 'tpu'", specifier = "==0.8.0" }, + { name = "jaxopt", specifier = ">=0.8.3" }, { name = "levanter", extras = ["serve"], editable = "lib/levanter" }, { name = "lm-eval", git = "https://github.com/stanford-crfm/lm-evaluation-harness?rev=d5e3391f22cde186c827674d5c3ec7c5f4fe0cab" }, { name = "lm-eval", extras = ["math"], marker = "extra == 'eval'", git = "https://github.com/stanford-crfm/lm-evaluation-harness?rev=d5e3391f22cde186c827674d5c3ec7c5f4fe0cab" }, @@ -4492,6 +4512,7 @@ requires-dist = [ { name = "numpy" }, { name = "openai" }, { name = "pandas" }, + { name = "plotly" }, { name = "prime", marker = "extra == 'rl'" }, { name = "pyarrow", specifier = ">=22" }, { name = "pylatexenc", marker = "extra == 'math'" }, @@ -4540,6 +4561,7 @@ dev = [ { name = "mypy", specifier = ">=1.4.1" }, { name = "openai-responses" }, { name = "pip" }, + { name = "plotly" }, { name = "pylatexenc" }, { name = "pymdown-extensions", specifier = ">=10.0.0" }, { name = "pyrefly", specifier = "==0.40.0" }, @@ -4586,6 +4608,7 @@ metrics = [{ name = "google-cloud-logging" }] test = [ { name = "openai-responses" }, { name = "pip" }, + { name = "plotly" }, { name = "pytest", specifier = ">=8.3.2" }, { name = "pytest-asyncio" }, { name = "pytest-cov" },