From 45e0937f15c6ad9e5fed39dd1e3af89ad89cf6f4 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 8 Dec 2025 15:07:28 -0500 Subject: [PATCH 01/79] Plotting Code --- .../src/marin/scaling_laws/isoflop_plot.py | 327 ++++++++++++++++++ 1 file changed, 327 insertions(+) create mode 100644 lib/marin/src/marin/scaling_laws/isoflop_plot.py diff --git a/lib/marin/src/marin/scaling_laws/isoflop_plot.py b/lib/marin/src/marin/scaling_laws/isoflop_plot.py new file mode 100644 index 0000000000..0a60891c9e --- /dev/null +++ b/lib/marin/src/marin/scaling_laws/isoflop_plot.py @@ -0,0 +1,327 @@ +import wandb +import plotly.io as pio +import jax.numpy as jnp +from jaxopt import ScipyMinimize +import pandas as pd +import plotly.graph_objects as go + +# ---------------- Theme ---------------- +pio.templates.default = "plotly_white" + +# ---------------- Constants ---------------- +PALETTE = [ + "#1877F2", + "#F0701A", + "#5A24C7", + "#E42C97", + "#00487C", + "#0EAC96", + "#AB76FF", + "#B50550", + "#0099E6", + "#22085F", + "#783301", +] +MARKERS = [ + "circle", + "square", + "cross", + "x", + "triangle-up", + "triangle-down", + "triangle-left", + "triangle-right", + "pentagon", + "hexagon", + "hexagon2", + "star", + "star-triangle-up", + "star-triangle-down", + "star-square", + "star-diamond", + "hourglass", + "bowtie", +] +DASHES = ["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"] +DEFAULT_METRIC_KEY = "eval/paloma/c4_en/bpb" +SEQ_LEN = 4096 + +_MIN_MARKER = dict(symbol="diamond", size=10, color="#000000") +_SCALE_MARKER = dict(symbol="circle", size=9, color=PALETTE[0]) +_SCALE_LINE = dict(dash="dot", width=2, color=PALETTE[0]) + +REQUIRED_TAGS = {"steps", "B", "FLOPs", "d", "L"} +CANON_LABELS = ["nemo", "comma", "dclm"] # canonical dataset names we detect in displayName + + +# ---------------- Helpers ---------------- +def _tags_to_dict(tags): + return {k: v for k, v in (t.split("=", 1) for t in tags if "=" in t)} + + +def df_from_sources(source_runs: list[tuple[list, str]], metric_key: str = DEFAULT_METRIC_KEY) -> pd.DataFrame: + """Build a dataframe from [(runs, fragment), ...] and compute a 'label' per row.""" + records = [] + for runs, fragment in source_runs: + for run in runs: + summary = run.summary + tags = _tags_to_dict(run.tags) + if not REQUIRED_TAGS.issubset(tags): + continue + + steps = float(tags["steps"]) + batch = float(tags["B"]) + flops = float(tags["FLOPs"]) + if flops < 1e18: + continue + + tokens = steps * batch * SEQ_LEN + loss = summary.get(metric_key) + if loss is None: + continue + + params = summary.get("parameter_count") + name = run.displayName + + records.append(dict(tokens=tokens, loss=loss, flops=flops, params=params, name=name, label=fragment)) + return pd.DataFrame.from_records(records) + + +def _robust_quad_logx(x: jnp.ndarray, y: jnp.ndarray, delta: float = 1.0) -> jnp.ndarray: + L = jnp.log10(x) + + def huber(residual): + abs_r = jnp.abs(residual) + quad = 0.5 * residual**2 + linear = delta * (abs_r - 0.5 * delta) + return jnp.where(abs_r <= delta, quad, linear) + + def objective(params): + a, b, c = params + pred = a * L**2 + b * L + c + residuals = y - pred + return jnp.sum(huber(residuals)) + + opt = ScipyMinimize(fun=objective, method="BFGS", value_and_grad=False) + init = jnp.array(jnp.polyfit(L, y, 2)) if len(L) >= 3 else jnp.array([0.0, *jnp.polyfit(L, y, 1)]) + return opt.run(init_params=init).params + + +def iso_plot_with_minima_df(df: pd.DataFrame): + """ + Expects df columns: tokens, loss, flops, params, name, label. 
+ ISO plot: + - points: color by compute bucket (FLOPs), marker shape by dataset label + - dashed parabolas: per-(label, FLOPs) robust quadratic fits (restored) + - minima per (label, FLOPs): black diamonds + SCALING plot: + - one N* ~ A*C^alpha fit line per dataset (distinct color/dash) + - dataset minima as points in matching color + """ + if df is None or df.empty: + return go.Figure(), go.Figure() + + present = list(dict.fromkeys(df["label"].tolist())) + datasets = [lab for lab in CANON_LABELS if lab in present] + [lab for lab in present if lab not in CANON_LABELS] + + # Visual maps + buckets = sorted(df.flops.unique()) + bucket_color = {C: PALETTE[i % len(PALETTE)] for i, C in enumerate(buckets)} # ISO: color = compute bucket + ds_marker = {lab: MARKERS[i % len(MARKERS)] for i, lab in enumerate(datasets)} # ISO: shape = dataset + DS_COLORS = PALETTE + DASHES = ["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"] + + fig_iso = go.Figure() + minima = [] # (label, C, N_star, loss) + + # ---- ISO: scatter, per-(label,C) parabola (RESTORED), and minima + for lab in datasets: + for C in buckets: + sub = df[(df.flops == C) & (df.label == lab)].sort_values("tokens") + if sub.empty: + continue + + # scatter + fig_iso.add_trace( + go.Scatter( + x=sub.tokens, + y=sub.loss, + mode="markers", + marker=dict(symbol=ds_marker[lab], color=bucket_color[C], size=8), + name=f"{lab}, {C:.2e} FLOPs", + legendgroup=f"{lab}, {C:.2e}", + hovertemplate=( + "C=%{text:.2e} FLOPs
<br>tokens=%{x:.3e}<br>" "loss=%{y:.4f}<br>
params=%{customdata:.3e}" + ), + text=[C] * len(sub), + customdata=sub.params.values, + ) + ) + + # robust quadratic fit in log10(tokens) + a, b, c = _robust_quad_logx(jnp.array(sub.tokens.values), jnp.array(sub.loss.values)) + if a == 0: + continue + + # draw the parabola for this (lab, C) + Ls = jnp.linspace(jnp.log10(sub.tokens.min()), jnp.log10(sub.tokens.max()), 200) + fig_iso.add_trace( + go.Scatter( + x=10**Ls, + y=a * Ls**2 + b * Ls + c, + mode="lines", + line=dict(color=bucket_color[C], dash="dash", width=2), + showlegend=False, # avoid legend clutter + legendgroup=f"{lab}, {C:.2e}", + ) + ) + + # compute and draw minimum + L_opt = -b / (2 * a) + N_star = float(10**L_opt) + loss_opt = float(a * L_opt**2 + b * L_opt + c) + params_opt = sub.iloc[(sub.tokens - N_star).abs().argmin()].params + minima.append((lab, float(C), N_star, loss_opt)) + + fig_iso.add_trace( + go.Scatter( + x=[N_star], + y=[loss_opt], + mode="markers", + marker=_MIN_MARKER, + showlegend=False, + legendgroup=f"{lab}, {C:.2e}", + hovertemplate=( + "Compute-optimal
" + "C=%{text:.2e} FLOPs
tokens=%{x:.3e}
" + "loss=%{y:.4f}
params=%{customdata:.3e}" + ), + text=[C], + customdata=[params_opt], + ) + ) + + fig_iso.update_layout( + template="plotly_white", + xaxis_type="log", + xaxis_title="Tokens (log scale)", + yaxis_title="Bits Per Byte Validation", + title="Marin IsoFLOP Suite", + width=1000, + height=600, + ) + + # ---- SCALING: separate line per dataset + if not minima: + return fig_iso, go.Figure() + + fig_scale = go.Figure() + by_lab = {} + for lab, C, N_star, _ in minima: + by_lab.setdefault(lab, []).append((C, N_star)) + + for i, lab in enumerate(datasets): + pts = by_lab.get(lab, []) + if not pts: + continue + pts = sorted(pts) + Cs, Ns = zip(*pts, strict=False) + Cs = jnp.array(Cs) + Ns = jnp.array(Ns) + + color = DS_COLORS[i % len(DS_COLORS)] + dash = DASHES[i % len(DASHES)] + + # plot minima points + fig_scale.add_trace( + go.Scatter( + x=list(map(float, Cs)), + y=list(map(float, Ns)), + mode="markers", + marker=dict(symbol=_SCALE_MARKER["symbol"], size=_SCALE_MARKER["size"], color=color), + name=f"{lab} minima", + legendgroup=lab, + ) + ) + + if len(Cs) >= 2: + alpha, logA = jnp.polyfit(jnp.log10(Cs), jnp.log10(Ns), 1) + A = 10**logA + Cmin, Cmax = float(Cs.min()), float(Cs.max()) + C_fit = jnp.logspace(jnp.log10(Cmin) - 0.1, jnp.log10(Cmax) + 0.1, 400) + N_fit = A * (C_fit**alpha) + + fig_scale.add_trace( + go.Scatter( + x=list(map(float, C_fit)), + y=list(map(float, N_fit)), + mode="lines", + line=dict(color=color, dash=dash, width=_SCALE_LINE["width"]), + name=f"{lab} fit", + legendgroup=lab, + ) + ) + + fig_scale.update_layout( + template="plotly_white", + xaxis_type="log", + yaxis_type="log", + xaxis_title="Compute budget C (FLOPs, log)", + yaxis_title="Optimal tokens N* (log)", + title="Scaling fits per dataset", + ) + + return fig_iso, fig_scale + + +# ---------------- Main ---------------- +def main(sources: list[tuple[str, str]]): + """ + sources: list of (ENTITY/PROJECT, REGEX_FRAGMENT) with single fragments (no '|'). + We query with r'isoflop.*()' and infer dataset labels from displayName, + falling back to the fragment so nothing gets dropped. 
+ """ + RUN_ID = "marin-scaling-suite-isoflop" + wandb.login() + run = wandb.init( + entity="marin-community", + project="marin-analysis", + job_type="isoflop-analysis", + id=RUN_ID, + resume="allow", + name="isoflop-analysis", + ) + + api = wandb.Api() + source_runs = [] + for entity_project, fragment in sources: + if "/" not in entity_project: + raise ValueError(f"Bad ENTITY/PROJECT: {entity_project}") + if not fragment: + raise ValueError("Empty regex fragment") + + regex = rf"isoflop.*({fragment}).*" + filters = {"displayName": {"$regex": regex}, "state": "finished"} + runs = api.runs(entity_project.strip(), filters=filters) + source_runs.append((runs, fragment.strip())) + + df = df_from_sources(source_runs) + fig_iso, fig_scaling = iso_plot_with_minima_df(df) + + wandb.log( + { + "isoFLOP_plot": wandb.Plotly(fig_iso), + "scaling_plot": wandb.Plotly(fig_scaling), + } + ) + run.finish() + + +if __name__ == "__main__": + SOURCES = [ + ("marin-community/marin", "nemo-wider-depth-adapt"), + ("marin-community/marin", "comma"), + ("stanford-mercury/marin", "dclm-default"), + ] + main(SOURCES) From 4076d481bfb5d9624288741a2a66759e48eabf45 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 8 Dec 2025 23:03:44 -0500 Subject: [PATCH 02/79] Run Evals in ISOFlop --- experiments/isoflop_sweep.py | 40 ++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index ca68974aae..2de450fe30 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -31,6 +31,8 @@ from levanter.optim.config import OptimizerConfig from levanter.utils.flop_utils import lm_flops_per_token +from experiments.evals.evals import default_eval +from experiments.evals.task_configs import MMLU_5_SHOT, EvalTaskConfig from experiments.common_pile.tokenize_common_pile import comma_main_mixture from experiments.defaults import default_tokenize, default_train from experiments.llama import compute_num_parameters, llama3_tokenizer @@ -130,6 +132,7 @@ class IsoFlopSweepConfig: lr_constant: float = 0.33 min_hidden_pow: int = 9 max_hidden_pow: int = 12 + eval_tasks: tuple[EvalTaskConfig, ...] | None = (MMLU_5_SHOT,) base_optimizer_config: OptimizerConfig = dataclasses.field( default_factory=lambda: CautiousConfig( learning_rate=1.0, # Placeholder @@ -263,11 +266,21 @@ def candidate_configs(cfg: IsoFlopSweepConfig, budget: float): yield (hidden_size, intermediate_dim, num_layers, n_heads, n_kv_heads, batch_size, train_steps, lr, b2) -def generate_isoflop_steps(config: IsoFlopSweepConfig, experiment_name: str) -> list[ExecutorStep]: - """Generate executor steps for an ISOFlop sweep.""" +def generate_isoflop_steps( + config: IsoFlopSweepConfig, + experiment_name: str, +) -> tuple[list[ExecutorStep], list[tuple[float, int, int, int, int]]]: + """Generate executor steps for an ISOFlop sweep. + + Returns: + A tuple of: + - steps: Training and evaluation ExecutorSteps for the sweep. + - metadata: (budget, hidden_size, num_layers, batch_size, train_steps) for each training run. 
+ """ - steps: list[ExecutorStep] = [] - metadata = [] + train_steps_list: list[ExecutorStep] = [] + eval_steps: list[ExecutorStep] = [] + metadata: list[tuple[float, int, int, int, int]] = [] vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) for budget in config.budgets: @@ -310,7 +323,7 @@ def generate_isoflop_steps(config: IsoFlopSweepConfig, experiment_name: str) -> ) run_name = f"isoflop-{budget:.0e}-d{hidden_size}-L{num_layers}-B{batch_size}-{experiment_name}" - step = default_train( + train_step = default_train( name=run_name, tokenized=config.tokenized_dataset, model_config=model_cfg, @@ -332,9 +345,20 @@ def generate_isoflop_steps(config: IsoFlopSweepConfig, experiment_name: str) -> "isoflop", run_name, ) - steps.append(step.with_output_path(static_output_path)) - - return steps, metadata + train_step = train_step.with_output_path(static_output_path) + train_steps_list.append(train_step) + + # Evaluation on the latest checkpoint for each ISOFlop run. + if config.eval_tasks: + eval_step = default_eval( + train_step, + resource_config=train_cfg.resources, + evals=config.eval_tasks, + ) + eval_steps.append(eval_step) + + all_steps: list[ExecutorStep] = [*train_steps_list, *eval_steps] + return all_steps, metadata def generate_isoflop_sweep( From 61e66f44c9d7b8ff8b59c97a822949717002af37 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 16 Dec 2025 10:12:54 -0800 Subject: [PATCH 03/79] Write Out JSONL for eval metrics rather than only storing in WandB --- lib/levanter/src/levanter/eval.py | 21 +++++++++++++++++++++ lib/levanter/src/levanter/main/train_lm.py | 6 ++++++ 2 files changed, 27 insertions(+) diff --git a/lib/levanter/src/levanter/eval.py b/lib/levanter/src/levanter/eval.py index 6d7b77e6ca..4b9e08b618 100644 --- a/lib/levanter/src/levanter/eval.py +++ b/lib/levanter/src/levanter/eval.py @@ -3,12 +3,15 @@ import asyncio import dataclasses +import json import logging +import os import warnings from collections import defaultdict from typing import Callable, Mapping, Optional, Sequence, TypeVar import equinox as eqx +import fsspec import jax.numpy as jnp import jmp import numpy as np @@ -173,6 +176,7 @@ def cb_tagged_lm_evaluate( eval_ema: bool = True, prefix: str = "eval", mp: jmp.Policy = None, + checkpoint_path: Optional[str] = None, ) -> Callable[[StepInfo], None]: """ Evaluates multiple tagged datasets using a given evaluation function. 
@@ -196,6 +200,7 @@ def cb_tagged_lm_evaluate( prefix: The prefix to use for logging the losses eval_current: Whether to evaluate the model's current parameters eval_ema: Whether to evaluate the EMA model (or other model averaged model) + checkpoint_path: If provided, write eval metrics to a JSONL file in this directory """ evaluator = TaggedEvaluator( @@ -207,10 +212,12 @@ def cb_tagged_lm_evaluate( def eval_callback(step: StepInfo): step_count = step.step + metrics_to_write = {} if eval_current: log_dict = eval_model(evaluator, step.model, prefix=prefix) levanter.tracker.log(log_dict, step=step_count) + metrics_to_write.update(log_dict) if not eval_current and step.state.model_averaging is None: raise ValueError("Cannot evaluate EMA model without model averaging, but you only want to evaluate EMA") @@ -218,6 +225,20 @@ def eval_callback(step: StepInfo): if eval_ema and step.state.model_averaging is not None: log_dict = eval_model(evaluator, step.eval_model, prefix=_join_prefix(prefix, "ema")) levanter.tracker.log(log_dict, step=step_count) + metrics_to_write.update(log_dict) + + # Write metrics to file if checkpoint_path is provided + if checkpoint_path is not None and metrics_to_write: + metrics_file = os.path.join(checkpoint_path, "eval_metrics.jsonl") + fs, _, _ = fsspec.get_fs_token_paths(metrics_file) + fs.makedirs(checkpoint_path, exist_ok=True) + with fs.open(metrics_file, "a") as f: + record = {"step": int(step_count), **metrics_to_write} + # Include WandB run info for backfill/lookup + wandb_info = levanter.tracker.current_tracker_info() + if wandb_info: + record["_tracker"] = wandb_info + f.write(json.dumps(record) + "\n") return diff --git a/lib/levanter/src/levanter/main/train_lm.py b/lib/levanter/src/levanter/main/train_lm.py index 68f3228e7b..aa45fca792 100644 --- a/lib/levanter/src/levanter/main/train_lm.py +++ b/lib/levanter/src/levanter/main/train_lm.py @@ -199,6 +199,11 @@ def loss_function(model: LmHeadModel, example: LmExample, *, key=None): if len(tagged_eval_datasets) == 0: logger.warning("No evaluation datasets provided.") else: + # Write eval metrics to the same directory as checkpoints + checkpoint_path = None + if config.trainer.checkpointer is not None: + checkpoint_path = config.trainer.checkpointer.expanded_path(trainer.run_id) + cb = levanter.eval.cb_tagged_lm_evaluate( EvalBatch, tagged_eval_datasets, @@ -207,6 +212,7 @@ def loss_function(model: LmHeadModel, example: LmExample, *, key=None): compute_axis_mapping, max_eval_examples_per_ds, mp=config.trainer.mp, + checkpoint_path=checkpoint_path, ) trainer.add_hook(cb, every=config.trainer.steps_per_eval) From 58535f56022a07894241a23a4dc93f150b68a23b Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 16 Dec 2025 10:41:27 -0800 Subject: [PATCH 04/79] Infra for jobs which read eval metrics --- .../marin/scaling_laws/eval_metrics_reader.py | 238 ++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 lib/marin/src/marin/scaling_laws/eval_metrics_reader.py diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py new file mode 100644 index 0000000000..1c3fecd10c --- /dev/null +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -0,0 +1,238 @@ +# Copyright 2025 Marin Authors +# SPDX-License-Identifier: Apache-2.0 +""" +Base infrastructure for eval metrics analysis. + +This module provides a base config and utilities for analysis jobs that +read eval_metrics.jsonl files from completed training runs. 
Specific +analyses (like IsoFlop) should subclass EvalMetricsAnalysisConfig. +""" + +import logging +import json +import os +from dataclasses import dataclass +from typing import Callable, Sequence + +import fsspec +import pandas as pd + +try: + import wandb + + WANDB_AVAILABLE = True +except ImportError: + WANDB_AVAILABLE = False + +logger = logging.getLogger(__name__) + +from marin.execution.executor import ExecutorStep, InputName, output_path_of, this_output_path + + +def extract_run_name_from_path(path: str) -> str: + """Extract run name (last component) from a checkpoint path. + + E.g., 'gs://bucket/checkpoints/my-run-abc123' -> 'my-run-abc123' + """ + return os.path.basename(path.rstrip("/")) + + +def _backfill_metrics_from_wandb( + checkpoint_path: str, + metrics_file: str, + entity_project: str, + wandb_run_id: str | None = None, +) -> bool: + """ + Backfill eval_metrics.jsonl from WandB for a training run. + + Args: + checkpoint_path: Path to the checkpoint directory + metrics_file: Full path to where eval_metrics.jsonl should be written + entity_project: WandB entity/project (format: 'entity/project') + wandb_run_id: If provided, use this WandB run ID instead of inferring from path + + Returns: + True if backfill succeeded, False otherwise + """ + if not WANDB_AVAILABLE: + logger.warning(f"wandb not available, cannot backfill metrics for {checkpoint_path}") + return False + + try: + run_id = wandb_run_id or extract_run_name_from_path(checkpoint_path) + logger.info(f"Attempting to backfill summary metrics for run_id: {run_id}") + + api = wandb.Api() + run = api.run(f"{entity_project}/{run_id}") + + # Get summary metrics only + summary = dict(run.summary) + + eval_metrics = {k: v for k, v in summary.items() if k.startswith("eval/")} + if not eval_metrics: + logger.warning(f"No eval summary metrics found in WandB for run {run_id}") + return False + record = { + "step": summary.get("_step", summary.get("trainer/global_step", 0)), + **eval_metrics, + } + + fs, _, _ = fsspec.get_fs_token_paths(metrics_file) + fs.makedirs(os.path.dirname(metrics_file), exist_ok=True) + + with fs.open(metrics_file, "w") as f: + f.write(json.dumps(record) + "\n") + + logger.info(f"Successfully backfilled summary metrics to {metrics_file}") + return True + + except Exception as e: + return False + + +@dataclass(frozen=True) +class EvalMetricsAnalysisConfig: + """Base config for analyses that read eval metrics from training runs. + + Subclass this to create specific analysis types (e.g., IsoFlopAnalysisConfig). + The training_runs field creates blocking dependencies on the training jobs. + """ + + training_runs: Sequence[str] + """List of training run output paths to read eval metrics from (blocks until complete).""" + + output_path: str + """Where to write analysis outputs.""" + + metrics_filename: str = "eval_metrics.jsonl" + """Name of the metrics file within each checkpoint directory.""" + + backfill_from_wandb: bool = True + """If True, backfill eval_metrics.jsonl from WandB for runs that completed before this feature.""" + + wandb_entity_project: str = "marin-community/marin" + """WandB entity/project to query for backfill (format: 'entity/project').""" + + wandb_run_overrides: dict[str, str] | None = None + """Manual mapping from checkpoint path (or run name) to WandB run ID. + + Use this when the checkpoint path doesn't match the WandB run ID. + Keys can be full paths or just the run name (basename of path). 
+ Example: {"isoflop-1e+19-d2048-nemo": "isoflop-1e+19-d2048-nemo-abc123"} + """ + + +def read_metrics_dataframe(config: EvalMetricsAnalysisConfig) -> pd.DataFrame: + """ + Read eval metrics from training runs into a DataFrame. + + This is the shared utility that all analysis subtypes use to load metrics. + It handles reading JSONL files and optional WandB backfill. + + Args: + config: Analysis config with training_runs and backfill settings + + Returns: + DataFrame with columns: step, run_index, run_path, + all eval/* metrics + """ + all_records = [] + + for i, run_path in enumerate(config.training_runs): + metrics_file = os.path.join(run_path, config.metrics_filename) + + fs, _, _ = fsspec.get_fs_token_paths(metrics_file) + + if not fs.exists(metrics_file): + logger.info(f"{metrics_file} does not exist") + + if config.backfill_from_wandb: + logger.info("Attempting to backfill from WandB...") + + # Check manual overrides (by full path or run name) + wandb_run_id = None + if config.wandb_run_overrides: + run_name = extract_run_name_from_path(run_path) + wandb_run_id = config.wandb_run_overrides.get(run_path) + if wandb_run_id is None: + wandb_run_id = config.wandb_run_overrides.get(run_name) + if wandb_run_id: + logger.info(f"Using manual override: {wandb_run_id}") + + success = _backfill_metrics_from_wandb( + checkpoint_path=run_path, + metrics_file=metrics_file, + entity_project=config.wandb_entity_project, + wandb_run_id=wandb_run_id, + ) + if not success: + raise RuntimeError( + f"Backfill from WandB failed for run {i} (path={run_path}, metrics_file={metrics_file})" + ) + else: + raise RuntimeError( + f"Metrics file missing for run {i} (path={run_path}), and backfill_from_wandb is disabled" + ) + + with fs.open(metrics_file, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + + record = json.loads(line) + record["run_index"] = i + record["run_path"] = run_path + all_records.append(record) + + if not all_records: + logger.warning("No eval metrics found in any training runs") + return pd.DataFrame() + + df = pd.DataFrame(all_records) + logger.info(f"Loaded {len(all_records)} evaluation records from {len(config.training_runs)} runs") + logger.info(f"Available columns: {list(df.columns)}") + return df + + +def create_analysis_step( + name: str, + training_runs: Sequence[ExecutorStep | InputName], + analysis_fn: Callable[[EvalMetricsAnalysisConfig], None], + config_class: type[EvalMetricsAnalysisConfig], + description: str | None = None, + **config_kwargs, +) -> ExecutorStep: + """ + Create an ExecutorStep for an eval metrics analysis. + + This is the factory for creating analysis steps. 
It: + - Converts training ExecutorSteps to blocking dependencies + - Creates the appropriate config subclass + - Returns an ExecutorStep that runs the analysis + + Args: + name: Name for this executor step + training_runs: Training run ExecutorSteps (creates blocking dependencies) + analysis_fn: The analysis function to run + config_class: The config class (EvalMetricsAnalysisConfig or subclass) + description: Optional description + **config_kwargs: Additional kwargs passed to config_class + + Returns: + ExecutorStep configured to run the analysis + """ + run_paths = [output_path_of(run) if isinstance(run, ExecutorStep) else run for run in training_runs] + + config = config_class( + training_runs=run_paths, + output_path=this_output_path(), + **config_kwargs, + ) + + return ExecutorStep( + name=name, + fn=analysis_fn, + config=config, + description=description or f"Analyze eval metrics from {len(training_runs)} training runs", + ) From c26760b1f4559ce2e3e05ed75a316aaf2e736b33 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 16 Dec 2025 11:50:23 -0800 Subject: [PATCH 05/79] Stash --- .../src/marin/scaling_laws/isoflop_plot.py | 285 +++++++++++++++++- 1 file changed, 281 insertions(+), 4 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_plot.py b/lib/marin/src/marin/scaling_laws/isoflop_plot.py index 0a60891c9e..9b76e8003c 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_plot.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_plot.py @@ -1,9 +1,68 @@ -import wandb -import plotly.io as pio +import logging +import os +import re +from dataclasses import dataclass +from typing import Sequence + +import fsspec import jax.numpy as jnp -from jaxopt import ScipyMinimize import pandas as pd import plotly.graph_objects as go +import plotly.io as pio +from jaxopt import ScipyMinimize + +try: + import wandb + + WANDB_AVAILABLE = True +except ImportError: + WANDB_AVAILABLE = False + +from marin.execution.executor import ExecutorStep, InputName, output_path_of +from marin.scaling_laws.eval_metrics_reader import ( + EvalMetricsAnalysisConfig, + create_analysis_step, + extract_run_name_from_path, + read_metrics_dataframe, +) + + +logger = logging.getLogger(__name__) + + +def build_wandb_run_overrides(wandb_sources: list[tuple[str, str]]) -> dict[str, str]: + """ + Builds a mapping from clean run names to full WandB displayNames. + This is used to find WandB runs for backfill, even when checkpoint paths + use the new clean names but WandB displayNames have legacy hash suffixes. 
+ """ + if not WANDB_AVAILABLE: + logger.warning("wandb not available, cannot build run overrides") + return {} + + api = wandb.Api() + overrides = {} # clean_name -> full_displayName + + for entity_project, fragment in wandb_sources: + if "/" not in entity_project: + raise ValueError(f"Bad ENTITY/PROJECT: {entity_project}") + + regex = rf"isoflop.*({fragment}).*" + filters = {"displayName": {"$regex": regex}, "state": "finished"} + try: + runs = api.runs(entity_project.strip(), filters=filters) + for run in runs: + display_name = run.displayName + # The key for the override map is the "clean" name, without hash + clean_name = re.sub(r"-[0-9a-fA-F]{6}$", "", display_name) + # The value is the full name, which is used as the run ID for backfill + overrides[clean_name] = display_name + except Exception as e: + logger.warning(f"Failed to query WandB for {entity_project}: {e}") + + logger.info(f"Built {len(overrides)} WandB run overrides") + return overrides + # ---------------- Theme ---------------- pio.templates.default = "plotly_white" @@ -59,6 +118,34 @@ def _tags_to_dict(tags): return {k: v for k, v in (t.split("=", 1) for t in tags if "=" in t)} +def _parse_isoflop_run_name(run_name: str) -> dict | None: + """Parse metadata from isoflop run name. + + Expected format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} + Optionally with a trailing - which is ignored. + E.g., 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt' + or 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt-a1b2c3' + + Returns dict with: flops, d, L, B, experiment_name or None if parsing fails. + """ + # Strip optional - suffix + run_name = re.sub(r"-[0-9a-fA-F]{6}$", "", run_name) + + pattern = r"isoflop-([0-9.e+]+)-d(\d+)-L(\d+)-B(\d+)-(.+)" + match = re.match(pattern, run_name) + if not match: + return None + + flops_str, d, L, B, exp_name = match.groups() + return { + "flops": float(flops_str), + "d": int(d), + "L": int(L), + "B": int(B), + "experiment_name": exp_name, + } + + def df_from_sources(source_runs: list[tuple[list, str]], metric_key: str = DEFAULT_METRIC_KEY) -> pd.DataFrame: """Build a dataframe from [(runs, fragment), ...] and compute a 'label' per row.""" records = [] @@ -275,7 +362,197 @@ def iso_plot_with_minima_df(df: pd.DataFrame): return fig_iso, fig_scale -# ---------------- Main ---------------- +# ---------------- Executor Integration ---------------- +@dataclass(frozen=True) +class IsoFlopAnalysisConfig(EvalMetricsAnalysisConfig): + """Config for isoflop analysis - extends base eval metrics analysis. + + Inherits training_runs, output_path, and backfill settings from base. + Adds isoflop-specific parameters. + """ + + metric_key: str = DEFAULT_METRIC_KEY + """Which metric to use for loss (default: eval/paloma/c4_en/bpb).""" + + label_map: dict[str, str] | None = None + """Optional mapping from experiment_name -> display label.""" + + upload_to_wandb: bool = True + """Whether to upload plots to WandB.""" + + wandb_entity: str = "marin-community" + wandb_project: str = "marin-analysis" + wandb_run_name: str = "isoflop-analysis" + + +def _transform_metrics_for_isoflop( + df: pd.DataFrame, + metric_key: str, + label_map: dict[str, str] | None, +) -> pd.DataFrame: + """Transform raw metrics DataFrame into isoflop plotting format. 
+ + Takes the generic metrics DataFrame from read_metrics_dataframe() and + transforms it into the format expected by iso_plot_with_minima_df(): + columns: tokens, loss, flops, params, name, label + """ + if df.empty: + return pd.DataFrame(columns=["tokens", "loss", "flops", "params", "name", "label"]) + + # Get final metrics for each run (max step) + final_metrics = df.loc[df.groupby("run_path")["step"].idxmax()].copy() + + records = [] + for _, row in final_metrics.iterrows(): + run_path = row["run_path"] + run_name = extract_run_name_from_path(run_path) + + # Parse metadata from run name + meta = _parse_isoflop_run_name(run_name) + if meta is None: + print(f"Warning: Could not parse metadata from run name: {run_name}") + continue + + flops = meta["flops"] + if flops < 1e18: + continue + + # Calculate tokens = steps * batch * seq_len + steps = row["step"] + batch = meta["B"] + tokens = steps * batch * SEQ_LEN + + # Get loss from the metric column + loss = row.get(metric_key) + if loss is None or pd.isna(loss): + print(f"Warning: Missing metric {metric_key} for run {run_name}") + continue + + params = row.get("parameter_count") + if params is None or pd.isna(params): + params = None + + # Determine label + exp_name = meta["experiment_name"] + if label_map and exp_name in label_map: + label = label_map[exp_name] + else: + label = exp_name + for canon in CANON_LABELS: + if canon in exp_name.lower(): + label = canon + break + + records.append( + dict( + tokens=tokens, + loss=loss, + flops=flops, + params=params, + name=run_name, + label=label, + ) + ) + + return pd.DataFrame.from_records(records) + + +def run_isoflop_analysis(config: IsoFlopAnalysisConfig) -> None: + """Run isoflop analysis from training runs. + + This is a subtype of eval metrics analysis that: + 1. Reads metrics using the base read_metrics_dataframe() + 2. Transforms them for isoflop plotting + 3. 
Generates and saves isoflop/scaling plots + """ + # Use inherited metrics reading from base + raw_df = read_metrics_dataframe(config) + + if raw_df.empty: + print("Warning: No eval metrics found") + return + + # Transform to isoflop format + df = _transform_metrics_for_isoflop(raw_df, config.metric_key, config.label_map) + + if df.empty: + print("Warning: No valid isoflop data after transformation") + return + + print(f"Transformed {len(df)} runs for isoflop analysis") + fig_iso, fig_scaling = iso_plot_with_minima_df(df) + + # Save plots locally + fs, _, _ = fsspec.get_fs_token_paths(config.output_path) + fs.makedirs(config.output_path, exist_ok=True) + + iso_path = os.path.join(config.output_path, "isoflop_plot.html") + scaling_path = os.path.join(config.output_path, "scaling_plot.html") + + with fs.open(iso_path, "w") as f: + f.write(fig_iso.to_html()) + print(f"Wrote isoflop plot to {iso_path}") + + with fs.open(scaling_path, "w") as f: + f.write(fig_scaling.to_html()) + print(f"Wrote scaling plot to {scaling_path}") + + # Optionally upload to WandB + if config.upload_to_wandb and WANDB_AVAILABLE: + wandb.login() + run = wandb.init( + entity=config.wandb_entity, + project=config.wandb_project, + job_type="isoflop-analysis", + name=config.wandb_run_name, + resume="allow", + ) + wandb.log( + { + "isoFLOP_plot": wandb.Plotly(fig_iso), + "scaling_plot": wandb.Plotly(fig_scaling), + } + ) + run.finish() + print("Uploaded plots to WandB") + + +def create_isoflop_analysis_step( + name: str, + training_runs: Sequence[ExecutorStep | InputName], + metric_key: str = DEFAULT_METRIC_KEY, + label_map: dict[str, str] | None = None, + upload_to_wandb: bool = True, + description: str | None = None, +) -> ExecutorStep: + """Create an ExecutorStep for isoflop analysis. + + This uses the base create_analysis_step() with IsoFlopAnalysisConfig. + + Args: + name: Name for this executor step + training_runs: Training run ExecutorSteps (creates blocking dependencies) + metric_key: Which metric to use for loss + label_map: Optional mapping from experiment_name -> display label + upload_to_wandb: Whether to upload plots to WandB + description: Optional description + + Returns: + ExecutorStep configured to run isoflop analysis + """ + return create_analysis_step( + name=name, + training_runs=training_runs, + analysis_fn=run_isoflop_analysis, + config_class=IsoFlopAnalysisConfig, + description=description or f"IsoFLOP analysis for {len(training_runs)} runs", + metric_key=metric_key, + label_map=label_map, + upload_to_wandb=upload_to_wandb, + ) + + +# ---------------- Main (Legacy WandB-based) ---------------- def main(sources: list[tuple[str, str]]): """ sources: list of (ENTITY/PROJECT, REGEX_FRAGMENT) with single fragments (no '|'). 
From 55d907439b59f3242146ad8d313a1c23d4c79234 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 16 Dec 2025 14:07:32 -0800 Subject: [PATCH 06/79] IsoFLOPS into ExecutorStep --- .../marin/scaling_laws/eval_metrics_reader.py | 96 +++++---- .../src/marin/scaling_laws/isoflop_plot.py | 187 ++++++++---------- 2 files changed, 128 insertions(+), 155 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index 1c3fecd10c..fa5f37e1f2 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -1,3 +1,17 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Copyright 2025 Marin Authors # SPDX-License-Identifier: Apache-2.0 """ @@ -12,7 +26,7 @@ import json import os from dataclasses import dataclass -from typing import Callable, Sequence +from collections.abc import Callable, Sequence import fsspec import pandas as pd @@ -41,7 +55,6 @@ def _backfill_metrics_from_wandb( checkpoint_path: str, metrics_file: str, entity_project: str, - wandb_run_id: str | None = None, ) -> bool: """ Backfill eval_metrics.jsonl from WandB for a training run. @@ -50,7 +63,6 @@ def _backfill_metrics_from_wandb( checkpoint_path: Path to the checkpoint directory metrics_file: Full path to where eval_metrics.jsonl should be written entity_project: WandB entity/project (format: 'entity/project') - wandb_run_id: If provided, use this WandB run ID instead of inferring from path Returns: True if backfill succeeded, False otherwise @@ -59,36 +71,37 @@ def _backfill_metrics_from_wandb( logger.warning(f"wandb not available, cannot backfill metrics for {checkpoint_path}") return False - try: - run_id = wandb_run_id or extract_run_name_from_path(checkpoint_path) - logger.info(f"Attempting to backfill summary metrics for run_id: {run_id}") + try: + run_id = extract_run_name_from_path(checkpoint_path) + logger.info(f"Attempting to backfill summary metrics for run_id: {run_id}") - api = wandb.Api() - run = api.run(f"{entity_project}/{run_id}") + api = wandb.Api() + run = api.run(f"{entity_project}/{run_id}") - # Get summary metrics only - summary = dict(run.summary) + # Get summary metrics only + summary = dict(run.summary) - eval_metrics = {k: v for k, v in summary.items() if k.startswith("eval/")} - if not eval_metrics: - logger.warning(f"No eval summary metrics found in WandB for run {run_id}") - return False - record = { - "step": summary.get("_step", summary.get("trainer/global_step", 0)), - **eval_metrics, - } + eval_metrics = {k: v for k, v in summary.items() if k.startswith("eval/")} + if not eval_metrics: + logger.warning(f"No eval summary metrics found in WandB for run {run_id}") + return False + record = { + "step": summary.get("_step", summary.get("trainer/global_step", 0)), + **eval_metrics, + } - fs, _, _ = fsspec.get_fs_token_paths(metrics_file) - fs.makedirs(os.path.dirname(metrics_file), exist_ok=True) + fs, _, _ = 
fsspec.get_fs_token_paths(metrics_file) + fs.makedirs(os.path.dirname(metrics_file), exist_ok=True) - with fs.open(metrics_file, "w") as f: - f.write(json.dumps(record) + "\n") + with fs.open(metrics_file, "w") as f: + f.write(json.dumps(record) + "\n") - logger.info(f"Successfully backfilled summary metrics to {metrics_file}") - return True + logger.info(f"Successfully backfilled summary metrics to {metrics_file}") + return True - except Exception as e: - return False + except Exception as e: + logger.warning(f"Failed to backfill metrics from WandB: {e}") + return False @dataclass(frozen=True) @@ -114,14 +127,6 @@ class EvalMetricsAnalysisConfig: wandb_entity_project: str = "marin-community/marin" """WandB entity/project to query for backfill (format: 'entity/project').""" - wandb_run_overrides: dict[str, str] | None = None - """Manual mapping from checkpoint path (or run name) to WandB run ID. - - Use this when the checkpoint path doesn't match the WandB run ID. - Keys can be full paths or just the run name (basename of path). - Example: {"isoflop-1e+19-d2048-nemo": "isoflop-1e+19-d2048-nemo-abc123"} - """ - def read_metrics_dataframe(config: EvalMetricsAnalysisConfig) -> pd.DataFrame: """ @@ -149,30 +154,19 @@ def read_metrics_dataframe(config: EvalMetricsAnalysisConfig) -> pd.DataFrame: if config.backfill_from_wandb: logger.info("Attempting to backfill from WandB...") - # Check manual overrides (by full path or run name) - wandb_run_id = None - if config.wandb_run_overrides: - run_name = extract_run_name_from_path(run_path) - wandb_run_id = config.wandb_run_overrides.get(run_path) - if wandb_run_id is None: - wandb_run_id = config.wandb_run_overrides.get(run_name) - if wandb_run_id: - logger.info(f"Using manual override: {wandb_run_id}") - success = _backfill_metrics_from_wandb( checkpoint_path=run_path, metrics_file=metrics_file, entity_project=config.wandb_entity_project, - wandb_run_id=wandb_run_id, ) - if not success: + if not success: + raise RuntimeError( + f"Backfill from WandB failed for run {i} (path={run_path}, metrics_file={metrics_file})" + ) + else: raise RuntimeError( - f"Backfill from WandB failed for run {i} (path={run_path}, metrics_file={metrics_file})" + f"Metrics file missing for run {i} (path={run_path}), and backfill_from_wandb is disabled" ) - else: - raise RuntimeError( - f"Metrics file missing for run {i} (path={run_path}), and backfill_from_wandb is disabled" - ) with fs.open(metrics_file, "r") as f: for line in f: diff --git a/lib/marin/src/marin/scaling_laws/isoflop_plot.py b/lib/marin/src/marin/scaling_laws/isoflop_plot.py index 9b76e8003c..30ce4c56db 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_plot.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_plot.py @@ -1,8 +1,22 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import logging import os import re from dataclasses import dataclass -from typing import Sequence +from collections.abc import Sequence import fsspec import jax.numpy as jnp @@ -26,42 +40,16 @@ read_metrics_dataframe, ) +try: + from experiments.isoflop_sweep import MARIN_SCALING_SUITES -logger = logging.getLogger(__name__) + ISOFLOP_SWEEP_AVAILABLE = True +except ImportError: + ISOFLOP_SWEEP_AVAILABLE = False + MARIN_SCALING_SUITES = {} -def build_wandb_run_overrides(wandb_sources: list[tuple[str, str]]) -> dict[str, str]: - """ - Builds a mapping from clean run names to full WandB displayNames. - This is used to find WandB runs for backfill, even when checkpoint paths - use the new clean names but WandB displayNames have legacy hash suffixes. - """ - if not WANDB_AVAILABLE: - logger.warning("wandb not available, cannot build run overrides") - return {} - - api = wandb.Api() - overrides = {} # clean_name -> full_displayName - - for entity_project, fragment in wandb_sources: - if "/" not in entity_project: - raise ValueError(f"Bad ENTITY/PROJECT: {entity_project}") - - regex = rf"isoflop.*({fragment}).*" - filters = {"displayName": {"$regex": regex}, "state": "finished"} - try: - runs = api.runs(entity_project.strip(), filters=filters) - for run in runs: - display_name = run.displayName - # The key for the override map is the "clean" name, without hash - clean_name = re.sub(r"-[0-9a-fA-F]{6}$", "", display_name) - # The value is the full name, which is used as the run ID for backfill - overrides[clean_name] = display_name - except Exception as e: - logger.warning(f"Failed to query WandB for {entity_project}: {e}") - - logger.info(f"Built {len(overrides)} WandB run overrides") - return overrides +logger = logging.getLogger(__name__) # ---------------- Theme ---------------- @@ -114,8 +102,6 @@ def build_wandb_run_overrides(wandb_sources: list[tuple[str, str]]) -> dict[str, # ---------------- Helpers ---------------- -def _tags_to_dict(tags): - return {k: v for k, v in (t.split("=", 1) for t in tags if "=" in t)} def _parse_isoflop_run_name(run_name: str) -> dict | None: @@ -146,34 +132,6 @@ def _parse_isoflop_run_name(run_name: str) -> dict | None: } -def df_from_sources(source_runs: list[tuple[list, str]], metric_key: str = DEFAULT_METRIC_KEY) -> pd.DataFrame: - """Build a dataframe from [(runs, fragment), ...] 
and compute a 'label' per row.""" - records = [] - for runs, fragment in source_runs: - for run in runs: - summary = run.summary - tags = _tags_to_dict(run.tags) - if not REQUIRED_TAGS.issubset(tags): - continue - - steps = float(tags["steps"]) - batch = float(tags["B"]) - flops = float(tags["FLOPs"]) - if flops < 1e18: - continue - - tokens = steps * batch * SEQ_LEN - loss = summary.get(metric_key) - if loss is None: - continue - - params = summary.get("parameter_count") - name = run.displayName - - records.append(dict(tokens=tokens, loss=loss, flops=flops, params=params, name=name, label=fragment)) - return pd.DataFrame.from_records(records) - - def _robust_quad_logx(x: jnp.ndarray, y: jnp.ndarray, delta: float = 1.0) -> jnp.ndarray: L = jnp.log10(x) @@ -552,53 +510,74 @@ def create_isoflop_analysis_step( ) -# ---------------- Main (Legacy WandB-based) ---------------- -def main(sources: list[tuple[str, str]]): +# ---------------- Main (using experiments/isoflop_sweep.py) ---------------- +def main_from_isoflop_sweep( + suite_names: list[str] | None = None, + metric_key: str = DEFAULT_METRIC_KEY, + upload_to_wandb: bool = True, +): """ - sources: list of (ENTITY/PROJECT, REGEX_FRAGMENT) with single fragments (no '|'). - We query with r'isoflop.*()' and infer dataset labels from displayName, - falling back to the fragment so nothing gets dropped. + Run isoflop analysis using training runs from experiments/isoflop_sweep.py. + + Args: + suite_names: Names of scaling suites from MARIN_SCALING_SUITES (default: all) + metric_key: Which metric to use for loss + upload_to_wandb: Whether to upload plots to WandB """ - RUN_ID = "marin-scaling-suite-isoflop" - wandb.login() - run = wandb.init( - entity="marin-community", - project="marin-analysis", - job_type="isoflop-analysis", - id=RUN_ID, - resume="allow", - name="isoflop-analysis", - ) + if not ISOFLOP_SWEEP_AVAILABLE: + raise RuntimeError( + "Cannot import from experiments.isoflop_sweep. " "Make sure the experiments module is in your Python path." 
+ ) - api = wandb.Api() - source_runs = [] - for entity_project, fragment in sources: - if "/" not in entity_project: - raise ValueError(f"Bad ENTITY/PROJECT: {entity_project}") - if not fragment: - raise ValueError("Empty regex fragment") + if suite_names is None: + suite_names = list(MARIN_SCALING_SUITES.keys()) - regex = rf"isoflop.*({fragment}).*" - filters = {"displayName": {"$regex": regex}, "state": "finished"} - runs = api.runs(entity_project.strip(), filters=filters) - source_runs.append((runs, fragment.strip())) + # Collect all training runs from the specified suites + all_training_runs = [] + label_map = {} - df = df_from_sources(source_runs) - fig_iso, fig_scaling = iso_plot_with_minima_df(df) + for suite_name in suite_names: + if suite_name not in MARIN_SCALING_SUITES: + logger.warning(f"Suite '{suite_name}' not found in MARIN_SCALING_SUITES") + continue + + steps, _ = MARIN_SCALING_SUITES[suite_name] + # Filter to just training steps (not eval steps) + training_steps = [step for step in steps if step.name.startswith("isoflop-")] + all_training_runs.extend(training_steps) + + # Build label map from experiment names + for step in training_steps: + meta = _parse_isoflop_run_name(step.name) + if meta: + exp_name = meta["experiment_name"] + # Map experiment name to canonical label + for canon in CANON_LABELS: + if canon in exp_name.lower(): + label_map[exp_name] = canon + break + + if not all_training_runs: + logger.error("No training runs found in specified suites") + return - wandb.log( - { - "isoFLOP_plot": wandb.Plotly(fig_iso), - "scaling_plot": wandb.Plotly(fig_scaling), - } + logger.info(f"Found {len(all_training_runs)} training runs across {len(suite_names)} suites") + + # Create and run analysis + config = IsoFlopAnalysisConfig( + training_runs=[output_path_of(step) for step in all_training_runs], + output_path="analysis/isoflop", + metric_key=metric_key, + label_map=label_map, + upload_to_wandb=upload_to_wandb, ) - run.finish() + + run_isoflop_analysis(config) if __name__ == "__main__": - SOURCES = [ - ("marin-community/marin", "nemo-wider-depth-adapt"), - ("marin-community/marin", "comma"), - ("stanford-mercury/marin", "dclm-default"), - ] - main(SOURCES) + # Use the new logic that imports from isoflop_sweep.py + main_from_isoflop_sweep( + suite_names=["nemotron", "common_pile", "dclm-default"], + upload_to_wandb=True, + ) From e49bd99cdbae1baf7c22225cd05ea16f3b9ece09 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 17 Dec 2025 22:42:22 -0800 Subject: [PATCH 07/79] Lots of refactoring --- docs/references/default-steps.md | 6 - experiments/defaults.py | 39 - experiments/exp1752_simulated_epoching.py | 158 --- experiments/exp1752_stackv2_vs_starcoder.py | 92 -- .../exp2166_scaling_ladder_analysis.py | 80 ++ experiments/isoflop_sweep.py | 301 ++---- experiments/tootsie/exp654_scaling_tootsie.py | 149 --- lib/marin/pyproject.toml | 1 + lib/marin/src/marin/scaling_laws/__init__.py | 46 + .../marin/scaling_laws/create_ladder_suite.py | 139 --- .../marin/scaling_laws/isoflop_analysis.py | 972 ++++++++++++++++++ .../src/marin/scaling_laws/isoflop_plot.py | 583 ----------- .../src/marin/scaling_laws/scaling_ladder.py | 411 ++++++++ .../src/marin/scaling_laws/scaling_laws.py | 218 ---- .../src/marin/scaling_laws/scaling_plots.py | 336 ++++++ lib/marin/src/marin/scaling_laws/utils.py | 667 ------------ .../migrations/migrate_isoflop_wandb_runs.py | 219 ++++ uv.lock | 16 + 18 files changed, 2157 insertions(+), 2276 deletions(-) delete mode 100644 
experiments/exp1752_simulated_epoching.py delete mode 100644 experiments/exp1752_stackv2_vs_starcoder.py create mode 100644 experiments/exp2166_scaling_ladder_analysis.py delete mode 100644 experiments/tootsie/exp654_scaling_tootsie.py delete mode 100644 lib/marin/src/marin/scaling_laws/create_ladder_suite.py create mode 100644 lib/marin/src/marin/scaling_laws/isoflop_analysis.py delete mode 100644 lib/marin/src/marin/scaling_laws/isoflop_plot.py create mode 100644 lib/marin/src/marin/scaling_laws/scaling_ladder.py delete mode 100644 lib/marin/src/marin/scaling_laws/scaling_laws.py create mode 100644 lib/marin/src/marin/scaling_laws/scaling_plots.py delete mode 100644 lib/marin/src/marin/scaling_laws/utils.py create mode 100644 scripts/migrations/migrate_isoflop_wandb_runs.py diff --git a/docs/references/default-steps.md b/docs/references/default-steps.md index 1f491067e1..628850c8d4 100644 --- a/docs/references/default-steps.md +++ b/docs/references/default-steps.md @@ -27,12 +27,6 @@ In general, you should reach for the default steps before writing your own. ::: experiments.defaults.simulated_epoching_train -## Scaling Law Prediction - -::: marin.scaling_laws.create_ladder_suite.scaling_law_suite - -::: experiments.defaults.default_scaling_law_pred - ## Evaluation ::: experiments.evals.evals.default_eval diff --git a/experiments/defaults.py b/experiments/defaults.py index 809f159d79..fa5b8ebb10 100644 --- a/experiments/defaults.py +++ b/experiments/defaults.py @@ -72,7 +72,6 @@ tokenize, ) from marin.processing.tokenize.tokenize import HfTokenizeConfig, TokenizeConfigBase -from marin.scaling_laws.scaling_laws import ScalingLawConfig, run_scaling_law_analysis from marin.training.training import ( TrainLmOnPodConfig, run_levanter_train_lm, @@ -637,41 +636,3 @@ def _get_tokenizer_for_train(tokenized: InputName | ExecutorStep | LMMixtureData raise ValueError(f"Could not determine tokenizer from {tokenized}") return tokenizer - - -def default_scaling_law_pred( - ladder_runs: Sequence[ExecutorStep | InputName | str], - pred_run: ExecutorStep | InputName | str | None = None, - task_losses: Sequence[str] = ("eval/paloma/c4_en/bpb",), - task_accuracies: Sequence[str] | Sequence[EvalTaskConfig] | None = None, -): - """ - Given a suite of small models, predict the performance on a number of (N, D) values. 
- """ - # get the executor steps or run IDs for the ladder runs and the pred run - ladder_steps_or_ids = [get_executor_step(run) if not isinstance(run, str) else run for run in ladder_runs] - - pred_run_or_id = None - if pred_run: - pred_run_or_id = get_executor_step(pred_run) if not isinstance(pred_run, str) else pred_run - - # convert the task accuracies to strings if they are `EvalTaskConfig`s - if task_accuracies is not None: - task_accuracies = convert_to_task_metrics(task_accuracies, metric="acc") - - if pred_run_or_id: - name = pred_run_or_id if isinstance(pred_run_or_id, str) else pred_run_or_id.name - else: - name = "projection" - - return ExecutorStep( - name=f"""scaling_laws/{name}""", - fn=run_scaling_law_analysis, - config=ScalingLawConfig( - name=name, - ladder_model_steps=ladder_steps_or_ids, - pred_model_step=pred_run_or_id, - task_losses=task_losses, - task_accuracies=task_accuracies, - ), - ) diff --git a/experiments/exp1752_simulated_epoching.py b/experiments/exp1752_simulated_epoching.py deleted file mode 100644 index a88bde0149..0000000000 --- a/experiments/exp1752_simulated_epoching.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Scaling law comparison between Stack v2 datasets and StarCoderData with simulated epoching.""" - -import dataclasses -import logging -from collections.abc import Sequence - -from levanter.data.text import LMMixtureDatasetConfig -from levanter.models.llama import LlamaConfig - -from experiments.common_pile.tokenize_common_pile import stackv2, stackv2_edu_filtered -from experiments.defaults import default_tokenize, simulated_epoching_train -from experiments.evals.task_configs import CORE_TASKS -from experiments.llama import llama3_tokenizer, llama_1_4b -from experiments.pretraining_datasets.dclm import dclm_components_llama3 -from experiments.simple_train_config import SimpleTrainConfig -from fray.cluster import ResourceConfig -from marin.execution.executor import ExecutorStep, InputName, executor_main - -TPU_TYPE = "v5p-8" -TAG = ["exp1752", "simulated_epoching"] - -STACK_V2_SWEEP_NAME = "exp1752-stack-v2-sim" -STACK_V2_EDU_SWEEP_NAME = "exp1752-stack-v2-edu-sim" -STARCODER_SWEEP_NAME = "exp1752-starcoderdata-sim" - -SIMULATED_TARGET_BUDGET_TOKENS = 15_000_000_000_000 # 15T tokens to mimic full-budget epoching behaviour - -training_config = SimpleTrainConfig( - resources=ResourceConfig.with_tpu(TPU_TYPE, slice_count=1), - train_batch_size=256, - learning_rate=1e-3, - weight_decay=0.1, - num_train_steps=200000, - warmup=1000, - decay=0.0, - lr_schedule="constant", - ema_beta=0.995, - steps_per_eval=500, - steps_per_task_eval=500, -) - - -def simulated_scaling_law_suite( - sweep_name: str, - tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, - *, - widths: Sequence[int] = (512, 768, 1024, 1536, 2048), - base_model_config: LlamaConfig = llama_1_4b, - tags: Sequence[str] = (), - intermediate_scale: float = 4, - training_config: SimpleTrainConfig = training_config, 
- base_lr: float = 3e-4 * 4096, - max_lr: float = 5e-3, - target_budget: int = SIMULATED_TARGET_BUDGET_TOKENS, -) -> Sequence[ExecutorStep]: - """Mirror scaling_law_suite but replace training with simulated epoching.""" - - steps: list[ExecutorStep] = [] - for width in widths: - intermediate_dim = _round_to_multiple(intermediate_scale * width, 128) - head_size = 128 # keeping this 128 means we can use splash attention - num_heads = width // head_size - num_kv_heads = min(num_heads, 8) - assert num_heads * head_size == width, f"Number of heads must divide width: {width} % {head_size} != 0" - - if num_heads % num_kv_heads != 0: - num_kv_heads = num_heads - - model_config = dataclasses.replace( - base_model_config, - hidden_dim=width, - intermediate_dim=intermediate_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - ) - - learning_rate = min(base_lr / width, max_lr) - lr_training_config = dataclasses.replace(training_config, learning_rate=learning_rate) - - logging.info(f"Creating simulated epoching step for {sweep_name}-{width} with lr {learning_rate}") - - steps.append( - simulated_epoching_train( - name=f"{sweep_name}-{width}", - tokenized=tokenized, - model_config=model_config, - train_config=lr_training_config, - target_budget=target_budget, - tags=tags, - eval_harness_tasks=CORE_TASKS, - ) - ) - return steps - - -def _round_to_multiple(x: float, multiple: int) -> int: - return int(multiple * round(x / multiple)) - - -stackv2_tokenized = default_tokenize( - name="common_pile_stackv2", - dataset=stackv2 / "documents", - tokenizer=llama3_tokenizer, -) - -stackv2_edu_tokenized = default_tokenize( - name="common_pile_stackv2_edu", - dataset=stackv2_edu_filtered, - tokenizer=llama3_tokenizer, -) - -stackv2_suite = simulated_scaling_law_suite( - sweep_name=STACK_V2_SWEEP_NAME, - tokenized=stackv2_tokenized, - tags=[*TAG, "stackv2"], - intermediate_scale=4, - training_config=training_config, -) - -stackv2_edu_suite = simulated_scaling_law_suite( - sweep_name=STACK_V2_EDU_SWEEP_NAME, - tokenized=stackv2_edu_tokenized, - tags=[*TAG, "stackv2_edu"], - intermediate_scale=4, - training_config=training_config, -) - -starcoder_suite = simulated_scaling_law_suite( - sweep_name=STARCODER_SWEEP_NAME, - tokenized=dclm_components_llama3["starcoderdata"], - tags=[*TAG, "starcoderdata"], - intermediate_scale=4, - training_config=training_config, -) - -if __name__ == "__main__": - executor_main( - steps=[ - *stackv2_suite, - *stackv2_edu_suite, - *starcoder_suite, - ], - description="Scaling law sweeps comparing Stack v2 with StarCoderData using simulated epoching.", - ) diff --git a/experiments/exp1752_stackv2_vs_starcoder.py b/experiments/exp1752_stackv2_vs_starcoder.py deleted file mode 100644 index da8b6c3746..0000000000 --- a/experiments/exp1752_stackv2_vs_starcoder.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Scaling law comparison between Stack v2 datasets and StarCoderData.""" - -from experiments.common_pile.tokenize_common_pile import stackv2, stackv2_edu_filtered -from experiments.defaults import default_tokenize -from experiments.llama import llama3_tokenizer -from experiments.pretraining_datasets.dclm import dclm_components_llama3 -from experiments.simple_train_config import SimpleTrainConfig -from fray.cluster import ResourceConfig -from marin.execution.executor import executor_main -from marin.scaling_laws.create_ladder_suite import scaling_law_suite - -TPU_TYPE = "v5p-8" -TAG = ["exp1752_stackv2_vs_starcoder"] - -STACK_V2_SWEEP_NAME = "exp1752-stack-v2" -STACK_V2_EDU_SWEEP_NAME = "exp1752-stack-v2-edu" -STARCODER_SWEEP_NAME = "exp1752-starcoderdata" - -training_config = SimpleTrainConfig( - resources=ResourceConfig.with_tpu(TPU_TYPE, slice_count=1), - train_batch_size=256, - learning_rate=1e-3, - weight_decay=0.1, - num_train_steps=200000, - warmup=1000, - decay=0.0, - lr_schedule="constant", - ema_beta=0.995, - steps_per_eval=500, - steps_per_task_eval=500, -) - - -stackv2_tokenized = default_tokenize( - name="common_pile_stackv2", - dataset=stackv2 / "documents", - tokenizer=llama3_tokenizer, -) - -stackv2_edu_tokenized = default_tokenize( - name="common_pile_stackv2_edu", - dataset=stackv2_edu_filtered, - tokenizer=llama3_tokenizer, -) - -stackv2_suite = scaling_law_suite( - sweep_name=STACK_V2_SWEEP_NAME, - tokenized=stackv2_tokenized, - tags=[*TAG, "stackv2"], - intermediate_scale=4, - training_config=training_config, -) - -stackv2_edu_suite = scaling_law_suite( - sweep_name=STACK_V2_EDU_SWEEP_NAME, - tokenized=stackv2_edu_tokenized, - tags=[*TAG, "stackv2_edu"], - intermediate_scale=4, - training_config=training_config, -) - -starcoder_suite = scaling_law_suite( - sweep_name=STARCODER_SWEEP_NAME, - tokenized=dclm_components_llama3["starcoderdata"], - tags=[*TAG, "starcoderdata"], - intermediate_scale=4, - training_config=training_config, -) - -if __name__ == "__main__": - executor_main( - steps=[ - *stackv2_suite, - *stackv2_edu_suite, - *starcoder_suite, - ], - description="Scaling law sweeps comparing Stack v2 with StarCoderData.", - ) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py new file mode 100644 index 0000000000..8f460572a4 --- /dev/null +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -0,0 +1,80 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exp2166: IsoFLOP Analysis and Scaling Ladders for Nemotron, Comma, and Dolma3. + +This experiment runs IsoFLOP analysis on the isoflop training sweeps +for three datasets: +- Nemotron (nemo-wider-depth-adapt) +- Common Pile / Comma (comma-mix) +- Dolma3 (dolma3-mix-150b-1025) + +The IsoFLOP analysis fits scaling laws to find compute-optimal configurations and +generates visualization plots. It also demonstrates scaling ladder runs (compute-optimal +training runs) that use the predicted configurations. 
+""" + +from experiments.isoflop_sweep import MARIN_SCALING_SUITES, nemotron_mix, dolma3_mix +from marin.execution.executor import executor_main +from marin.scaling_laws import isoflop_analysis_step, scaling_ladder_suite + +# Get training steps for each dataset (eval_tasks=None by default, so only training steps) +nemotron_training, _ = MARIN_SCALING_SUITES["nemotron"] +comma_training, _ = MARIN_SCALING_SUITES["common_pile"] +dolma3_training, _ = MARIN_SCALING_SUITES["dolma3_mix_150b"] + + +# --- IsoFLOP analysis-only steps (no scaling ladder rungs) --- + +nemotron_analysis = isoflop_analysis_step( + name="exp2166-isoflop-analysis-nemotron", + training_runs=nemotron_training, + wandb_run_name="exp2166-isoflop-analysis-nemotron", +) + + +dolma3_analysis = isoflop_analysis_step( + name="exp2166-isoflop-analysis-dolma3", + training_runs=dolma3_training, + wandb_run_name="exp2166-isoflop-analysis-dolma3", +) + + +# --- Full scaling ladder suites --- +# These create IsoFLOP analysis + scaling ladder rungs (optimal training runs) for target budgets + +# Nemotron suite: analyze isoflop runs, then train optimal models at larger budgets +nemotron_suite = scaling_ladder_suite( + name="exp2166-nemo", + training_runs=nemotron_training, + target_budgets=[1e21, 3e21], + label="nemo", + dataset=nemotron_mix, +) + + +# Dolma3 suite +dolma3_suite = scaling_ladder_suite( + name="exp2166-dolma3", + training_runs=dolma3_training, + target_budgets=[1e21, 3e21], + label="dolma3", + dataset=dolma3_mix, +) + + +all_steps = [nemotron_analysis, dolma3_analysis, *nemotron_suite.all_steps, *dolma3_suite.all_steps] + +if __name__ == "__main__": + executor_main(steps=all_steps) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 2de450fe30..3c615e544a 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Generate ISOFlop sweep steps for varying model sizes on a target datasett. +"""Generate ISOFlop sweep steps for varying model sizes on a target dataset. This script constructs `ExecutorStep` objects that train models of different sizes while keeping the total training FLOPs roughly constant. It is intended @@ -20,7 +20,6 @@ """ import dataclasses -import math import os from dataclasses import dataclass, replace @@ -29,7 +28,6 @@ from levanter.models.qwen import Qwen3Config from levanter.optim.cautious import CautiousConfig from levanter.optim.config import OptimizerConfig -from levanter.utils.flop_utils import lm_flops_per_token from experiments.evals.evals import default_eval from experiments.evals.task_configs import MMLU_5_SHOT, EvalTaskConfig @@ -43,96 +41,29 @@ from fray.cluster import ResourceConfig from marin.execution.executor import ExecutorStep, InputName, executor_main from marin.processing.tokenize import lm_mixture_data_config - -DEFAULT_BUDGETS = [1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20] -MLP_RATIO = 4 - -# TPU v5p hardware constants for memory estimation -# Constants for TPU v5p -HBM_PER_CHIP_GIB = 95 -CORES_PER_CHIP = 2 -V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512] # TPU slices - - -def estimate_bytes( - param_count: int, - hidden_dim: int, - num_layers: int, - batch: int, - seq_len: int, - vocab: int, - optim_mult: int = 3, - dtype_size: int = 4, - fudge_factor: float = 2, -) -> int: - """ - Estimate float32 memory usage (in bytes) for one training step. 
- Note(Will): I had to do more fudging than expected on this, - but not seems to work ok. - - Parameters: - - hidden_dim: model hidden size - - num_layers: number of Transformer layers - - batch, seq_len: training batch size and sequence length - - vocab: vocabulary size - - optim_mult: optimizer memory multiplier (e.g., 100x for Adam + states) - - dtype_size: bytes per float (4 for float32) - - fudge_factor: safety margin for extra memory - - Returns: - - total estimated memory in bytes - """ - param_bytes = param_count * optim_mult * dtype_size - - act_bytes = (batch * seq_len) * ((hidden_dim * num_layers) + vocab * fudge_factor) - - total_bytes = param_bytes + act_bytes - return int(total_bytes) * fudge_factor +from marin.scaling_laws.isoflop_analysis import ( + CandidateConfig, + DEFAULT_BUDGETS, + IsoFlopSweepConfig, + candidate_configs, + pick_v5p_type, +) -def pick_v5p_type( - config: Qwen3Config, - hidden: int, - layers: int, - batch: int, - seq_len: int, - vocab: int, -) -> str: - """ - Select the smallest TPU v5p slice that fits the model in float32. +@dataclass +class IsoFlopExperimentConfig(IsoFlopSweepConfig): + """Extended config for isoflop experiments with dataset and eval settings. - Returns: - - TPU slice name, e.g., "v5p-8" or "v5p-32" + Inherits core sweep parameters from IsoFlopSweepConfig and adds + experiment-specific settings like tokenized dataset and eval tasks. """ - param_count = compute_num_parameters(config, vocab) - need_bytes = estimate_bytes(param_count, hidden, layers, batch, seq_len, vocab) - chip_bytes = HBM_PER_CHIP_GIB * 1024**3 - chips = math.ceil(need_bytes / chip_bytes) - cores_req = chips * CORES_PER_CHIP - - valid = [c for c in V5P_CORE_OPTIONS if c >= cores_req] - if not valid: - raise ValueError(f"Model too large for available v5p slices (need {cores_req} cores).") - return f"v5p-{min(valid)}" + tokenized_dataset: InputName | str = "" + """Tokenized dataset to train on.""" + eval_tasks: tuple[EvalTaskConfig, ...] | None = None + """Evaluation tasks to run after training (disabled by default).""" -@dataclass -class IsoFlopSweepConfig: - """Configuration for generating ISOFlop sweep steps.""" - - tokenized_dataset: InputName | str - tokenizer: str = "stanford-crfm/marin-tokenizer" - budgets: list[float] = dataclasses.field(default_factory=lambda: DEFAULT_BUDGETS) - seq_len: int = 4096 - steps_per_run: int = 2**16 - flop_tolerance: float = 0.01 - base_hidden_layer_ratio: int = 64 - hidden_head_ratio: int = 128 - lr_constant: float = 0.33 - min_hidden_pow: int = 9 - max_hidden_pow: int = 12 - eval_tasks: tuple[EvalTaskConfig, ...] 
| None = (MMLU_5_SHOT,) base_optimizer_config: OptimizerConfig = dataclasses.field( default_factory=lambda: CautiousConfig( learning_rate=1.0, # Placeholder @@ -148,6 +79,7 @@ class IsoFlopSweepConfig: decay=0.2, ), ) + base_train_config: SimpleTrainConfig = dataclasses.field( default_factory=lambda: SimpleTrainConfig( resources=ResourceConfig.with_tpu("v5p-8"), @@ -162,167 +94,74 @@ class IsoFlopSweepConfig: ) -def round_to_power_of_two(x: float) -> int: - """Round ``x`` to the nearest power of two.""" - - if x <= 1: - return 1 - return 2 ** math.ceil(math.log2(x)) - - -def compute_total_flops( - batch: int, - num_layers: int, +def _pick_v5p_type_for_model( + config: Qwen3Config, hidden: int, - intermediate: int, - num_kv_heads: int, - num_heads: int, - steps: int, + layers: int, + batch: int, seq_len: int, - vocab_size: int, -) -> float: - """Compute total training FLOPs using Levanter utilities.""" - - flops_per_token = lm_flops_per_token( - hidden, - intermediate, - num_layers, - num_kv_heads, - num_heads, - seq_len, - vocab_size, - glu=True, - ) - return flops_per_token * batch * steps * seq_len - - -def candidate_configs(cfg: IsoFlopSweepConfig, budget: float): - """Yield candidate model configurations within the FLOP budget.""" - - vocab_size = get_vocab_size_for_tokenizer(cfg.tokenizer) - - if budget > 9e18: - step_size = 256 - else: - step_size = 128 - - for hidden_size in range(2**cfg.min_hidden_pow, (2**cfg.max_hidden_pow) + 1, step_size): - hs_pow = math.log2(hidden_size) - intermediate_dim = hidden_size * MLP_RATIO - num_layers = round(hidden_size / (cfg.base_hidden_layer_ratio + (hs_pow * 4) - cfg.min_hidden_pow)) - n_heads = max(1, hidden_size // cfg.hidden_head_ratio) - n_kv_heads = n_heads - - batch_exact = budget / compute_total_flops( - 1, - num_layers, - hidden_size, - intermediate_dim, - n_kv_heads, - n_heads, - cfg.steps_per_run, - cfg.seq_len, - vocab_size, - ) - - batch_size = round_to_power_of_two(batch_exact) - lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size - while lr > 0.01: - batch_size //= 2 - lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size - b2 = 0.98 ** (batch_size / 128) # https://arxiv.org/pdf/2507.07101 - - if batch_size < 8: - continue - - steps_exact = budget / compute_total_flops( - batch_size, - num_layers, - hidden_size, - intermediate_dim, - n_kv_heads, - n_heads, - 1, - cfg.seq_len, - vocab_size, - ) - train_steps = round(steps_exact) - - achieved_flops = compute_total_flops( - batch_size, - num_layers, - hidden_size, - intermediate_dim, - n_kv_heads, - n_heads, - train_steps, - cfg.seq_len, - vocab_size, - ) - - if abs(achieved_flops - budget) / budget > cfg.flop_tolerance: - continue - - yield (hidden_size, intermediate_dim, num_layers, n_heads, n_kv_heads, batch_size, train_steps, lr, b2) + vocab: int, +) -> str: + """Select the smallest TPU v5p slice that fits the model in float32.""" + param_count = compute_num_parameters(config, vocab) + return pick_v5p_type(param_count, hidden, layers, batch, seq_len, vocab) def generate_isoflop_steps( - config: IsoFlopSweepConfig, + config: IsoFlopExperimentConfig, experiment_name: str, -) -> tuple[list[ExecutorStep], list[tuple[float, int, int, int, int]]]: +) -> tuple[list[ExecutorStep], list[CandidateConfig]]: """Generate executor steps for an ISOFlop sweep. Returns: A tuple of: - steps: Training and evaluation ExecutorSteps for the sweep. - - metadata: (budget, hidden_size, num_layers, batch_size, train_steps) for each training run. 
+ - candidates: CandidateConfig for each training run (contains budget, hidden_size, + num_layers, batch_size, train_steps, learning_rate, etc.) """ train_steps_list: list[ExecutorStep] = [] eval_steps: list[ExecutorStep] = [] - metadata: list[tuple[float, int, int, int, int]] = [] + candidates: list[CandidateConfig] = [] vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) for budget in config.budgets: - for ( - hidden_size, - intermediate_dim, - num_layers, - n_heads, - n_kv_heads, - batch_size, - train_steps, - lr, - b2, - ) in candidate_configs(config, budget): + for candidate in candidate_configs(config, budget, vocab_size): model_cfg = Qwen3Config( max_seq_len=config.seq_len, - hidden_dim=hidden_size, - intermediate_dim=intermediate_dim, - num_heads=n_heads, - num_kv_heads=n_kv_heads, - num_layers=num_layers, + hidden_dim=candidate.hidden_size, + intermediate_dim=candidate.intermediate_dim, + num_heads=candidate.num_heads, + num_kv_heads=candidate.num_kv_heads, + num_layers=candidate.num_layers, rope=Llama3RotaryEmbeddingsConfig(), ) - tpu_type = pick_v5p_type( + tpu_type = _pick_v5p_type_for_model( config=model_cfg, - hidden=hidden_size, - layers=num_layers, - batch=batch_size, + hidden=candidate.hidden_size, + layers=candidate.num_layers, + batch=candidate.batch_size, seq_len=config.seq_len, vocab=vocab_size, ) - optimizer_cfg = replace(config.base_optimizer_config, learning_rate=lr, beta2=b2) + optimizer_cfg = replace( + config.base_optimizer_config, + learning_rate=candidate.learning_rate, + beta2=candidate.beta2, + ) train_cfg = replace( config.base_train_config, - train_batch_size=batch_size, - learning_rate=lr, - num_train_steps=train_steps, + train_batch_size=candidate.batch_size, + learning_rate=candidate.learning_rate, + num_train_steps=candidate.train_steps, resources=ResourceConfig.with_tpu(tpu_type), optimizer_config=optimizer_cfg, ) - run_name = f"isoflop-{budget:.0e}-d{hidden_size}-L{num_layers}-B{batch_size}-{experiment_name}" + run_name = ( + f"isoflop-{budget:.0e}-d{candidate.hidden_size}-" + f"L{candidate.num_layers}-B{candidate.batch_size}-{experiment_name}" + ) train_step = default_train( name=run_name, tokenized=config.tokenized_dataset, @@ -331,14 +170,14 @@ def generate_isoflop_steps( eval_harness_tasks=[], tags=( f"FLOPs={budget:.1e}", - f"d={hidden_size}", - f"L={num_layers}", - f"B={batch_size}", - f"steps={train_steps}", + f"d={candidate.hidden_size}", + f"L={candidate.num_layers}", + f"B={candidate.batch_size}", + f"steps={candidate.train_steps}", f"tpu={tpu_type}", ), ) - metadata.append((budget, hidden_size, num_layers, batch_size, train_steps)) + candidates.append(candidate) # Reuse checkpoints by pinning every sweep run to a deterministic directory. static_output_path = os.path.join( "checkpoints", @@ -358,18 +197,30 @@ def generate_isoflop_steps( eval_steps.append(eval_step) all_steps: list[ExecutorStep] = [*train_steps_list, *eval_steps] - return all_steps, metadata + return all_steps, candidates def generate_isoflop_sweep( tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, experiment_name: str, **kwargs, -) -> list[ExecutorStep]: - sweep_cfg = IsoFlopSweepConfig(tokenized_dataset=tokenized, **kwargs) - steps, metadata = generate_isoflop_steps(sweep_cfg, experiment_name) +) -> tuple[list[ExecutorStep], list[CandidateConfig]]: + """Generate an ISOFlop sweep for a tokenized dataset. + + Args: + tokenized: Tokenized dataset to train on. + experiment_name: Name suffix for the experiment (e.g., 'nemo', 'dclm'). 
+ **kwargs: Additional arguments passed to IsoFlopExperimentConfig. + + Returns: + A tuple of: + - steps: Training and evaluation ExecutorSteps for the sweep. + - candidates: CandidateConfig for each training run with full config details. + """ + sweep_cfg = IsoFlopExperimentConfig(tokenized_dataset=tokenized, **kwargs) + steps, candidates = generate_isoflop_steps(sweep_cfg, experiment_name) - return steps, metadata + return steps, candidates dclm_tokenized = dataclasses.replace( diff --git a/experiments/tootsie/exp654_scaling_tootsie.py b/experiments/tootsie/exp654_scaling_tootsie.py deleted file mode 100644 index d5393270a6..0000000000 --- a/experiments/tootsie/exp654_scaling_tootsie.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from experiments.pretraining_datasets.dclm import dclm_mixture_config_llama3_wrong -from marin.execution.executor import executor_main -import dataclasses -import logging -from collections.abc import Sequence - -from levanter.data.text import LMMixtureDatasetConfig -from levanter.models.llama import LlamaConfig - -from experiments.defaults import default_train -from experiments.llama import llama_1_4b -from experiments.simple_train_config import SimpleTrainConfig -from fray.cluster import ResourceConfig -from marin.execution.executor import ExecutorStep, InputName - -DEFAULT_MODEL_CONFIG = LlamaConfig( - max_seq_len=4096, - hidden_dim=2048, - intermediate_dim=7168, - num_heads=16, - num_kv_heads=8, - num_layers=16, -) - -# WSD-S training configuration -DEFAULT_SWEEP_TRAIN_CONFIG = SimpleTrainConfig( - resources=ResourceConfig.with_tpu("v4-128"), - train_batch_size=1024, - learning_rate=1e-3, # will be replaced in the scaling law suite - weight_decay=0.1, - # https://arxiv.org/pdf/2412.04403 gets 4 points per run. this gives us 5 - num_train_steps=50000, # 4096 * 1024 * 50000 = ~200B tokens - cycle_length=10000, # 5 cycles with 10000 steps/cycle - steps_per_eval=10000, # same as cycle length - warmup=1000, # initial warmup - decay=0.1, # 10% decay - lr_schedule="inv", # inv decay -) - - -# TODO(dlwh): in an old levanter branch (wandb_sweeps) i had fancier sweep generation stuff for doing surgery on the -# config. Consider using that. - - -def scaling_law_suite( - sweep_name: str, - tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, - widths: Sequence[int] = (512, 768, 1024, 1536, 2048), - base_model_config: LlamaConfig = llama_1_4b, - tags: Sequence[str] = (), - *, - intermediate_scale: float = 8, - training_config: SimpleTrainConfig = DEFAULT_SWEEP_TRAIN_CONFIG, - base_lr: float = 3e-4 * 4096, - max_lr: float = 5e-3, -) -> Sequence[ExecutorStep]: - """ - Provides width-wise scaling suite using WSD-S (or other) training configurations. 
- - Assumptions (consistent with llama 3): - * 128 head_dim - * 8 key-value heads unless that doesn't work with head_dim = 128 - * intermediate_dim = _round_to_multiple(intermediate_scale * width, 128) - * all widths are divisible by 128 - * peak lr is scaled to be base_lr / width, but clamped to max_lr - - Args: - sweep_name: prefix for the sweep name. runs will be named {sweep_name}-{width}-{hash} - base_model_config: base model configuration. Sweep will be generated by varying the width. - tokenized: input data for training - widths: range of widths to sweep over - training_config: training configuration - - References: - * default widths are from https://arxiv.org/pdf/2412.04403 table 1 (plus 512) - * incredibly wide intermediate_scale is based on the same table - * base_lr is based on llama 3 (https://arxiv.org/pdf/2407.21783 table 3) - * max_lr is a reasonable value that is not too high - * default model config (1_4b) gives the number of layers used in https://arxiv.org/pdf/2412.04403 table 1 - * lr scaling is based on µP/µTransfer: https://arxiv.org/pdf/2203.03466 where generally speaking, lr should - be scaled down by the width of the model. - """ - - steps = [] - for w in widths: - intermediate_dim = _round_to_multiple(intermediate_scale * w, 128) - head_size = 128 # keeping this 128 means we can use splash attention - num_heads = w // head_size - num_kv_heads = min(num_heads, 8) - assert num_heads * head_size == w, f"Number of heads must divide width: {w} % {head_size} != 0" - - # if num_kv_heads doesn't divide num_heads, we need to adjust num_kv_heads - if num_heads % num_kv_heads != 0: - num_kv_heads = num_heads - - model_config = dataclasses.replace( - base_model_config, - hidden_dim=w, - intermediate_dim=intermediate_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - ) - - lr = min(base_lr / w, max_lr) - training_config = dataclasses.replace(training_config, learning_rate=lr) - - logging.info(f"Creating training step for {sweep_name}-{w} with width {w} and lr {lr}") - - steps.append( - default_train( - name=f"{sweep_name}-{w}", - tokenized=tokenized, - model_config=model_config, - train_config=training_config, - tags=tags, - ) - ) - return steps - - -def _round_to_multiple(x, multiple): - return int(multiple * round(x / multiple)) - - -TAG = ["654_scaling_tootsie"] - -suite = scaling_law_suite(sweep_name="tootsie-scaling", tokenized=dclm_mixture_config_llama3_wrong, tags=TAG) - -if __name__ == "__main__": - executor_main( - steps=[ - *suite, - ], - description="scaling law suite to predict performance of 8B model on DCLM mix", - ) diff --git a/lib/marin/pyproject.toml b/lib/marin/pyproject.toml index e0cb23b441..8f8e1f1c91 100644 --- a/lib/marin/pyproject.toml +++ b/lib/marin/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "google-cloud-storage", "google-cloud-storage-transfer", "jax==0.8.0", # vllm-tpu currently requires this exact version + "jaxopt>=0.8.3", "haliax", "levanter[serve]", "lm-eval@git+https://github.com/stanford-crfm/lm-evaluation-harness@d5e3391f22cde186c827674d5c3ec7c5f4fe0cab", diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 731b4c72e7..12ffbfaecd 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -11,3 +11,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from marin.scaling_laws.isoflop_analysis import ( + DEFAULT_BUDGETS, + CandidateConfig, + IsoFlopAnalysisConfig, + IsoFlopAnalysisResult, + IsoFlopSweepConfig, + candidate_configs, + isoflop_analysis_step, + pick_v5p_type, + predict_optimal_config, + predict_optimal_configs_for_budgets, + run_isoflop_analysis, +) +from marin.scaling_laws.scaling_ladder import ( + ScalingLadderRungConfig, + ScalingLadderSuite, + scaling_ladder_rung_step, + scaling_ladder_suite, +) + +# Plotting functions are imported separately to avoid plotly dependency in core module +# from marin.scaling_laws.scaling_plots import create_isoflop_plot, create_scaling_plot, save_plots + +__all__ = [ + # Primary interface (ExecutorStep factories) + "isoflop_analysis_step", + "scaling_ladder_suite", + "scaling_ladder_rung_step", + # Programmatic interface + "run_isoflop_analysis", + # Dataclasses + "CandidateConfig", + "IsoFlopAnalysisConfig", + "IsoFlopAnalysisResult", + "IsoFlopSweepConfig", + "ScalingLadderSuite", + "ScalingLadderRungConfig", + # Constants + "DEFAULT_BUDGETS", + # Utilities + "candidate_configs", + "pick_v5p_type", + "predict_optimal_config", + "predict_optimal_configs_for_budgets", +] diff --git a/lib/marin/src/marin/scaling_laws/create_ladder_suite.py b/lib/marin/src/marin/scaling_laws/create_ladder_suite.py deleted file mode 100644 index 5f147e7aef..0000000000 --- a/lib/marin/src/marin/scaling_laws/create_ladder_suite.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Creates a suite of runs for scaling laws- based on https://arxiv.org/pdf/2412.04403 and https://github.com/marin-community/marin/issues/646. -""" - -import dataclasses -import logging -from collections.abc import Sequence - -from fray.cluster import ResourceConfig -from levanter.data.text import LMMixtureDatasetConfig -from levanter.models.llama import LlamaConfig - -from experiments.defaults import default_train -from experiments.evals.task_configs import CORE_TASKS_PLUS_MMLU -from experiments.llama import llama_1_4b -from experiments.simple_train_config import SimpleTrainConfig -from marin.execution.executor import ExecutorStep, InputName - -DEFAULT_MODEL_CONFIG = LlamaConfig( - max_seq_len=4096, - hidden_dim=2048, - intermediate_dim=7168, - num_heads=16, - num_kv_heads=8, - num_layers=16, -) - -WS_EMA_DEFAULT_TRAIN_CONFIG = SimpleTrainConfig( - resources=ResourceConfig.with_tpu("v4-128", slice_count=1), - train_batch_size=1024, - learning_rate=1e-3, # placeholder, this will be replaced in the scaling law suite - weight_decay=0.1, - # https://arxiv.org/pdf/2412.04403 gets 4 points per run. 
this gives us 5 - num_train_steps=50000, # 4096 * 1024 * 50000 = ~200B tokens - warmup=1000, # initial warmup - decay=0.0, # no decay - lr_schedule="constant", - ema_beta=0.995, - steps_per_eval=500, - steps_per_task_eval=500, -) - - -def scaling_law_suite( - sweep_name: str, - tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, - widths: Sequence[int] = (512, 768, 1024, 1536, 2048), - base_model_config: LlamaConfig = llama_1_4b, - tags: Sequence[str] = (), - *, - intermediate_scale: float = 4, - training_config: SimpleTrainConfig = WS_EMA_DEFAULT_TRAIN_CONFIG, - base_lr: float = 3e-4 * 4096, - max_lr: float = 5e-3, -) -> Sequence[ExecutorStep]: - """ - Provides width-wise scaling suite using WSD-S (or other) training configurations. - - Assumptions (consistent with llama 3): - * 128 head_dim - * 8 key-value heads unless that doesn't work with head_dim = 128 - * intermediate_dim = _round_to_multiple(intermediate_scale * width, 128) - * all widths are divisible by 128 - * peak lr is scaled to be base_lr / width, but clamped to max_lr - - Args: - sweep_name: prefix for the sweep name. runs will be named {sweep_name}-{width}-{hash} - base_model_config: base model configuration. Sweep will be generated by varying the width. - tokenized: input data for training - widths: range of widths to sweep over - training_config: training configuration - - References: - * default widths are from https://arxiv.org/pdf/2412.04403 table 1 (plus 512) - * intermediate scale is 4; should be 8 based on https://arxiv.org/pdf/2412.04403 table 1, - but we ultimately decided to go with a smaller value based on - https://arxiv.org/pdf/2407.21783 table 3 since 8 seemed large compared to - other works. - * base_lr is based on llama 3 (https://arxiv.org/pdf/2407.21783 table 3) - * max_lr is a reasonable value that is not too high - * default model config (1_4b) gives the number of layers used in https://arxiv.org/pdf/2412.04403 table 1 - * lr scaling is based on µP/µTransfer: https://arxiv.org/pdf/2203.03466 where generally speaking, lr should - be scaled down by the width of the model. 
- """ - - steps = [] - for w in widths: - intermediate_dim = _round_to_multiple(intermediate_scale * w, 128) - head_size = 128 # keeping this 128 means we can use splash attention - num_heads = w // head_size - num_kv_heads = min(num_heads, 8) - assert num_heads * head_size == w, f"Number of heads must divide width: {w} % {head_size} != 0" - - # if num_kv_heads doesn't divide num_heads, we need to adjust num_kv_heads - if num_heads % num_kv_heads != 0: - num_kv_heads = num_heads - - model_config = dataclasses.replace( - base_model_config, - hidden_dim=w, - intermediate_dim=intermediate_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - ) - - lr = min(base_lr / w, max_lr) - training_config = dataclasses.replace(training_config, learning_rate=lr) - - logging.info(f"Creating training step for {sweep_name}-{w} with width {w} and lr {lr}") - - steps.append( - default_train( - name=f"{sweep_name}-{w}", - tokenized=tokenized, - model_config=model_config, - train_config=training_config, - tags=tags, - eval_harness_tasks=CORE_TASKS_PLUS_MMLU, - ) - ) - return steps - - -def _round_to_multiple(x, multiple): - return int(multiple * round(x / multiple)) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py new file mode 100644 index 0000000000..fb6d30c200 --- /dev/null +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -0,0 +1,972 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""IsoFLOP analysis for finding compute-optimal training configurations. + +Primary usage - create an ExecutorStep for your pipeline: + + from marin.scaling_laws import isoflop_analysis_step + + analysis = isoflop_analysis_step( + name="my-scaling-analysis", + training_runs=my_training_steps, # list of ExecutorStep + ) + +The step will: +1. Read eval metrics from completed training runs +2. Fit scaling laws to find compute-optimal token counts +3. Save plots and results to the output path + +For programmatic use, see `run_isoflop_analysis()` which returns a `IsoFlopAnalysisResult`. 
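+
+For example, to inspect the fitted laws and project a larger budget (names below
+are placeholders):
+
+    from marin.scaling_laws import predict_optimal_config, run_isoflop_analysis
+
+    result = run_isoflop_analysis(training_runs=my_training_steps)
+    for label, (alpha, A) in result.scaling_fits.items():
+        print(f"{label}: N* = {A:.2e} * C^{alpha:.3f}")
+
+    # Use the fitted law to pick a near-optimal config at a larger budget.
+    best = predict_optimal_config(result.scaling_fits, target_flops=1e21, label="nemo")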
+""" + +import json +import logging +import math +import os +import re +from collections.abc import Iterator, Sequence +from dataclasses import asdict, dataclass, field + +import fsspec +import jax.numpy as jnp +import pandas as pd +from jaxopt import ScipyMinimize +from levanter.utils.flop_utils import lm_flops_per_token + +from marin.execution.executor import ExecutorStep, InputName, output_path_of, this_output_path +from marin.scaling_laws.eval_metrics_reader import ( + EvalMetricsAnalysisConfig, + extract_run_name_from_path, + read_metrics_dataframe, +) + + +logger = logging.getLogger(__name__) + +# ---------------- Constants ---------------- +DEFAULT_METRIC_KEY = "eval/paloma/c4_en/bpb" +SEQ_LEN = 4096 +CANON_LABELS = ["nemo", "comma", "dclm"] + +# ---------------- IsoFLOP Sweep Constants ---------------- +DEFAULT_BUDGETS = [1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20] +MLP_RATIO = 4 + +# TPU v5p hardware constants for memory estimation +HBM_PER_CHIP_GIB = 95 +CORES_PER_CHIP = 2 +V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512] + + +# ---------------- IsoFLOP Sweep Config ---------------- +@dataclass +class IsoFlopSweepConfig: + """Configuration for generating ISOFlop sweep candidate configs. + + This config controls the model architecture search space and training + hyperparameters for isoflop experiments. + """ + + tokenizer: str = "stanford-crfm/marin-tokenizer" + """Tokenizer to use (needed for vocab size).""" + + budgets: list[float] = field(default_factory=lambda: DEFAULT_BUDGETS.copy()) + """List of FLOP budgets to generate configs for.""" + + seq_len: int = 4096 + """Sequence length for training.""" + + steps_per_run: int = 2**16 + """Target number of training steps per run.""" + + flop_tolerance: float = 0.01 + """Tolerance for matching FLOP budget (relative error).""" + + base_hidden_layer_ratio: int = 64 + """Base ratio for hidden_dim to num_layers calculation.""" + + hidden_head_ratio: int = 128 + """Ratio for hidden_dim to num_heads calculation.""" + + lr_constant: float = 0.33 + """Constant for learning rate calculation: lr = (lr_constant * sqrt(batch)) / hidden_dim.""" + + min_hidden_pow: int = 9 + """Minimum hidden dimension as power of 2 (2^9 = 512).""" + + max_hidden_pow: int = 12 + """Maximum hidden dimension as power of 2 (2^12 = 4096).""" + + +# ---------------- Candidate Config ---------------- + + +@dataclass +class CandidateConfig: + """A candidate model/training configuration from the isoflop sweep. + + This dataclass contains all the information needed to create a training run. + Callers are responsible for converting this to their specific config format + (e.g., SimpleTrainConfig, Qwen3Config). 
+ """ + + hidden_size: int + intermediate_dim: int + num_layers: int + num_heads: int + num_kv_heads: int + batch_size: int + train_steps: int + learning_rate: float + beta2: float + tokens: float # total tokens = batch_size * train_steps * seq_len + flops_budget: float = 0.0 # the FLOP budget this config was generated for + + +# ---------------- Candidate Config Generation ---------------- + + +def round_to_power_of_two(x: float) -> int: + """Round ``x`` to the nearest power of two.""" + if x <= 1: + return 1 + return 2 ** math.ceil(math.log2(x)) + + +def compute_total_flops( + batch: int, + num_layers: int, + hidden: int, + intermediate: int, + num_kv_heads: int, + num_heads: int, + steps: int, + seq_len: int, + vocab_size: int, +) -> float: + """Compute total training FLOPs using Levanter utilities.""" + flops_per_token = lm_flops_per_token( + hidden, + intermediate, + num_layers, + num_kv_heads, + num_heads, + seq_len, + vocab_size, + glu=True, + ) + return flops_per_token * batch * steps * seq_len + + +def estimate_memory_bytes( + param_count: int, + hidden_dim: int, + num_layers: int, + batch: int, + seq_len: int, + vocab: int, + optim_mult: int = 3, + dtype_size: int = 4, + fudge_factor: float = 2, +) -> int: + """ + Estimate float32 memory usage (in bytes) for one training step. + + Parameters: + - param_count: number of model parameters + - hidden_dim: model hidden size + - num_layers: number of Transformer layers + - batch, seq_len: training batch size and sequence length + - vocab: vocabulary size + - optim_mult: optimizer memory multiplier (e.g., 3x for Adam + states) + - dtype_size: bytes per float (4 for float32) + - fudge_factor: safety margin for extra memory + + Returns: + - total estimated memory in bytes + """ + param_bytes = param_count * optim_mult * dtype_size + act_bytes = (batch * seq_len) * ((hidden_dim * num_layers) + vocab * fudge_factor) + total_bytes = param_bytes + act_bytes + return int(total_bytes * fudge_factor) + + +def pick_v5p_type( + param_count: int, + hidden: int, + layers: int, + batch: int, + seq_len: int, + vocab: int, +) -> str: + """ + Select the smallest TPU v5p slice that fits the model in float32. + + Returns: + - TPU slice name, e.g., "v5p-8" or "v5p-32" + """ + need_bytes = estimate_memory_bytes(param_count, hidden, layers, batch, seq_len, vocab) + chip_bytes = HBM_PER_CHIP_GIB * 1024**3 + chips = math.ceil(need_bytes / chip_bytes) + cores_req = chips * CORES_PER_CHIP + + valid = [c for c in V5P_CORE_OPTIONS if c >= cores_req] + if not valid: + raise ValueError(f"Model too large for available v5p slices (need {cores_req} cores).") + + return f"v5p-{min(valid)}" + + +def candidate_configs( + cfg: IsoFlopSweepConfig, + budget: float, + vocab_size: int, +) -> Iterator[CandidateConfig]: + """Yield candidate model configurations within the FLOP budget. 
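+
+    For each hidden size in the search range, the depth, head count, and MLP width
+    are derived from the configured ratios. The batch size is then solved so that
+    steps_per_run steps land on the budget, rounded up to a power of two, and halved
+    until the derived learning rate is at most 0.01; the step count is re-solved for
+    the final batch size. Candidates with batch size below 8, or whose achieved FLOPs
+    miss the budget by more than flop_tolerance, are skipped.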
+ + Args: + cfg: IsoFlopSweepConfig with search parameters + budget: Target FLOP budget + vocab_size: Vocabulary size for the tokenizer + + Yields: + CandidateConfig objects for each valid configuration + """ + if budget > 9e18: + step_size = 256 + else: + step_size = 128 + + for hidden_size in range(2**cfg.min_hidden_pow, (2**cfg.max_hidden_pow) + 1, step_size): + hs_pow = math.log2(hidden_size) + intermediate_dim = hidden_size * MLP_RATIO + num_layers = round(hidden_size / (cfg.base_hidden_layer_ratio + (hs_pow * 4) - cfg.min_hidden_pow)) + n_heads = max(1, hidden_size // cfg.hidden_head_ratio) + n_kv_heads = n_heads + + batch_exact = budget / compute_total_flops( + 1, + num_layers, + hidden_size, + intermediate_dim, + n_kv_heads, + n_heads, + cfg.steps_per_run, + cfg.seq_len, + vocab_size, + ) + + batch_size = round_to_power_of_two(batch_exact) + lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size + while lr > 0.01: + batch_size //= 2 + lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size + b2 = 0.98 ** (batch_size / 128) # https://arxiv.org/pdf/2507.07101 + + if batch_size < 8: + continue + + steps_exact = budget / compute_total_flops( + batch_size, + num_layers, + hidden_size, + intermediate_dim, + n_kv_heads, + n_heads, + 1, + cfg.seq_len, + vocab_size, + ) + train_steps = round(steps_exact) + + achieved_flops = compute_total_flops( + batch_size, + num_layers, + hidden_size, + intermediate_dim, + n_kv_heads, + n_heads, + train_steps, + cfg.seq_len, + vocab_size, + ) + + if abs(achieved_flops - budget) / budget > cfg.flop_tolerance: + continue + + tokens = batch_size * train_steps * cfg.seq_len + + yield CandidateConfig( + hidden_size=hidden_size, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + num_heads=n_heads, + num_kv_heads=n_kv_heads, + batch_size=batch_size, + train_steps=train_steps, + learning_rate=lr, + beta2=b2, + tokens=tokens, + flops_budget=budget, + ) + + +# ---------------- Helpers ---------------- + + +def parse_isoflop_run_name(run_name: str) -> dict | None: + """Parse metadata from isoflop run name. + + Expected format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} + Optionally with a trailing - which is ignored. + E.g., 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt' + or 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt-a1b2c3' + + Returns dict with: flops, d, L, B, experiment_name or None if parsing fails. + """ + # Strip optional - suffix + run_name = re.sub(r"-[0-9a-fA-F]{6}$", "", run_name) + + pattern = r"isoflop-([0-9.e+]+)-d(\d+)-L(\d+)-B(\d+)-(.+)" + match = re.match(pattern, run_name) + if not match: + return None + + flops_str, d, L, B, exp_name = match.groups() + return { + "flops": float(flops_str), + "d": int(d), + "L": int(L), + "B": int(B), + "experiment_name": exp_name, + } + + +def robust_quad_logx(x: jnp.ndarray, y: jnp.ndarray, delta: float = 1.0) -> tuple[float, float, float]: + """Fit a robust quadratic in log10(x) space using Huber loss. 
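+
+    The Huber loss is quadratic for residuals no larger than delta and linear beyond,
+    so a single noisy run cannot dominate the per-budget parabola. The fitted minimum
+    at log10(x) = -b / (2 * a) is what downstream analysis uses as the
+    compute-optimal token count.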
+ + Returns (a, b, c) coefficients for: loss = a * log10(x)^2 + b * log10(x) + c + """ + L = jnp.log10(x) + + def huber(residual): + abs_r = jnp.abs(residual) + quad = 0.5 * residual**2 + linear = delta * (abs_r - 0.5 * delta) + return jnp.where(abs_r <= delta, quad, linear) + + def objective(params): + a, b, c = params + pred = a * L**2 + b * L + c + residuals = y - pred + return jnp.sum(huber(residuals)) + + opt = ScipyMinimize(fun=objective, method="BFGS", value_and_grad=False) + init = jnp.array(jnp.polyfit(L, y, 2)) if len(L) >= 3 else jnp.array([0.0, *jnp.polyfit(L, y, 1)]) + result = opt.run(init_params=init).params + return float(result[0]), float(result[1]), float(result[2]) + + +def _compute_optimal_params(flops: float, tokens: float) -> float: + """Compute optimal parameters from C = 6 * N * P approximation.""" + return flops / (6 * tokens) + + +def _find_nearest_config(df: pd.DataFrame, flops: float, tokens: float) -> dict: + """Find the nearest actual config from the dataframe to use as template.""" + sub = df[df.flops == flops] + if sub.empty: + sub = df + idx = (sub.tokens - tokens).abs().argmin() + row = sub.iloc[idx] + + run_name = row["name"] + meta = parse_isoflop_run_name(run_name) + + return { + "hidden_dim": meta["d"] if meta else 0, + "num_layers": meta["L"] if meta else 0, + "batch_size": meta["B"] if meta else 0, + "params": row.get("params", _compute_optimal_params(flops, tokens)), + } + + +# ---------------- Core Analysis ---------------- + + +def fit_scaling_laws( + df: pd.DataFrame, +) -> tuple[list[dict], dict[str, tuple[float, float]], dict[tuple[str, float], tuple[float, float, float]]]: + """ + Fit scaling laws and extract optimal configurations. + + Args: + df: DataFrame with columns: tokens, loss, flops, params, name, label + + Returns: + - minima_records: List of dicts with optimal config info per (label, flops) + - scaling_fits: Dict of {label: (alpha, A)} for N* ~ A * C^alpha + - fit_curves: Dict of {(label, flops): (a, b, c)} quadratic coefficients for plotting + """ + if df is None or df.empty: + return [], {}, {} + + present = list(dict.fromkeys(df["label"].tolist())) + datasets = [lab for lab in CANON_LABELS if lab in present] + [lab for lab in present if lab not in CANON_LABELS] + + buckets = sorted(df.flops.unique()) + + minima_records = [] + fit_curves = {} + + # Fit quadratic for each (label, budget) and find minima + for lab in datasets: + for C in buckets: + sub = df[(df.flops == C) & (df.label == lab)].sort_values("tokens") + if sub.empty: + continue + + # Robust quadratic fit in log10(tokens) + a, b, c = robust_quad_logx(jnp.array(sub.tokens.values), jnp.array(sub.loss.values)) + fit_curves[(lab, C)] = (a, b, c) + + if a == 0: + continue + + # Compute minimum + L_opt = -b / (2 * a) + N_star = float(10**L_opt) + loss_opt = float(a * L_opt**2 + b * L_opt + c) + + # Find nearest actual config for template + nearest = _find_nearest_config(sub, C, N_star) + + minima_records.append( + { + "label": lab, + "flops": float(C), + "optimal_tokens": N_star, + "loss_at_optimal": loss_opt, + "hidden_dim": nearest["hidden_dim"], + "num_layers": nearest["num_layers"], + "batch_size": nearest["batch_size"], + "optimal_params": float(nearest["params"]), + } + ) + + # Fit scaling law N* ~ A * C^alpha per dataset + scaling_fits = {} + by_lab = {} + for rec in minima_records: + by_lab.setdefault(rec["label"], []).append(rec) + + for lab in datasets: + recs = by_lab.get(lab, []) + if len(recs) < 2: + continue + + recs = sorted(recs, key=lambda r: r["flops"]) + Cs 
= jnp.array([r["flops"] for r in recs]) + Ns = jnp.array([r["optimal_tokens"] for r in recs]) + + alpha, logA = jnp.polyfit(jnp.log10(Cs), jnp.log10(Ns), 1) + A = float(10**logA) + alpha = float(alpha) + scaling_fits[lab] = (alpha, A) + + # Augment minima records with scaling fit params + for rec in recs: + rec["scaling_alpha"] = alpha + rec["scaling_A"] = A + + return minima_records, scaling_fits, fit_curves + + +def transform_metrics_for_isoflop( + df: pd.DataFrame, + metric_key: str, + label_map: dict[str, str] | None = None, +) -> pd.DataFrame: + """Transform raw metrics DataFrame into isoflop analysis format. + + Takes the generic metrics DataFrame from read_metrics_dataframe() and + transforms it into the format expected by the analysis: + columns: tokens, loss, flops, params, name, label + + Args: + df: Raw metrics DataFrame from read_metrics_dataframe() + metric_key: Which metric column to use for loss + label_map: Optional mapping from experiment_name -> display label + + Returns: + Transformed DataFrame ready for fit_scaling_laws() + """ + if df.empty: + return pd.DataFrame(columns=["tokens", "loss", "flops", "params", "name", "label"]) + + # Get final metrics for each run (max step) + final_metrics = df.loc[df.groupby("run_path")["step"].idxmax()].copy() + + records = [] + for _, row in final_metrics.iterrows(): + run_path = row["run_path"] + run_name = extract_run_name_from_path(run_path) + + # Parse metadata from run name + meta = parse_isoflop_run_name(run_name) + if meta is None: + logger.warning(f"Could not parse metadata from run name: {run_name}") + continue + + flops = meta["flops"] + if flops < 1e18: + continue + + # Calculate tokens = steps * batch * seq_len + steps = row["step"] + batch = meta["B"] + tokens = steps * batch * SEQ_LEN + + # Get loss from the metric column + loss = row.get(metric_key) + if loss is None or pd.isna(loss): + logger.warning(f"Missing metric {metric_key} for run {run_name}") + continue + + params = row.get("parameter_count") + if params is None or pd.isna(params): + params = None + + # Determine label + exp_name = meta["experiment_name"] + if label_map and exp_name in label_map: + label = label_map[exp_name] + else: + label = exp_name + for canon in CANON_LABELS: + if canon in exp_name.lower(): + label = canon + break + + records.append( + dict( + tokens=tokens, + loss=loss, + flops=flops, + params=params, + name=run_name, + label=label, + ) + ) + + return pd.DataFrame.from_records(records) + + +# ---------------- Predict Optimal Config ---------------- + + +def predict_optimal_config( + scaling_fits: dict[str, tuple[float, float]], + target_flops: float, + label: str, + sweep_config: IsoFlopSweepConfig | None = None, + vocab_size: int = 128256, +) -> CandidateConfig | None: + """Predict optimal training config for a target compute budget using fitted scaling laws. + + This function: + 1. Uses the scaling fit (N* ~ A * C^alpha) to predict optimal tokens for target_flops + 2. Generates candidate configs for the target budget using candidate_configs() + 3. Selects the candidate whose token count is closest to the predicted optimal + + Args: + scaling_fits: Dict of {label: (alpha, A)} from scaling ladder result. + target_flops: Target compute budget in FLOPs. + label: Dataset/experiment label to use for scaling fit. + sweep_config: Optional IsoFlopSweepConfig. If None, uses defaults. + vocab_size: Vocabulary size (default: 128256 for marin tokenizer). 
+ + Returns: + CandidateConfig for the predicted optimal, or None if label not in fits + or no valid candidates found. + """ + if label not in scaling_fits: + logger.warning(f"Label '{label}' not found in scaling fits") + return None + + alpha, A = scaling_fits[label] + optimal_tokens = A * (target_flops**alpha) + + logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") + + # Use default config if none provided + if sweep_config is None: + sweep_config = IsoFlopSweepConfig() + + # Generate candidates for this budget + candidates = list(candidate_configs(sweep_config, target_flops, vocab_size)) + + if not candidates: + logger.warning(f"No valid candidates found for budget {target_flops:.2e}") + return None + + # Find candidate closest to optimal token count + best = min(candidates, key=lambda c: abs(c.tokens - optimal_tokens)) + + logger.info( + f"Selected config: d={best.hidden_size}, L={best.num_layers}, " + f"B={best.batch_size}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})" + ) + + return best + + +def predict_optimal_configs_for_budgets( + scaling_fits: dict[str, tuple[float, float]], + target_budgets: list[float], + label: str, + sweep_config: IsoFlopSweepConfig | None = None, + vocab_size: int = 128256, +) -> list[CandidateConfig]: + """Predict optimal configs for multiple target compute budgets. + + Args: + scaling_fits: Dict of {label: (alpha, A)} from scaling ladder result. + target_budgets: List of target compute budgets in FLOPs. + label: Dataset/experiment label to use for scaling fit. + sweep_config: Optional IsoFlopSweepConfig. If None, uses defaults. + vocab_size: Vocabulary size. + + Returns: + List of CandidateConfig for each budget (skips budgets with no valid config). + """ + configs = [] + for budget in target_budgets: + config = predict_optimal_config(scaling_fits, budget, label, sweep_config, vocab_size) + if config is not None: + configs.append(config) + return configs + + +# ---------------- Result Dataclass ---------------- + + +@dataclass +class IsoFlopAnalysisResult: + """Result from scaling ladder analysis containing optimal configs and analysis data.""" + + configs: list[CandidateConfig] + """List of optimal CandidateConfig for each (label, flops_budget) pair.""" + + scaling_fits: dict[str, tuple[float, float]] + """Per-label scaling fits: {label: (alpha, A)} for N* ~ A * C^alpha.""" + + isoflop_df: pd.DataFrame + """Transformed dataframe used for analysis.""" + + minima_records: list[dict] + """Raw minima records with detailed info for each optimum.""" + + fit_curves: dict[tuple[str, float], tuple[float, float, float]] + """Quadratic fit coefficients {(label, flops): (a, b, c)} for plotting.""" + + def to_json_dict(self) -> dict: + """Convert result to JSON-serializable dict (excludes DataFrame and fit_curves).""" + return { + "configs": [asdict(c) for c in self.configs], + "scaling_fits": {k: list(v) for k, v in self.scaling_fits.items()}, + "minima_records": self.minima_records, + } + + +# ---------------- ExecutorStep Config ---------------- + + +@dataclass(frozen=True) +class IsoFlopAnalysisConfig(EvalMetricsAnalysisConfig): + """Configuration for scaling ladder analysis ExecutorStep.""" + + metric_key: str = DEFAULT_METRIC_KEY + """Metric to use for loss (default: eval/paloma/c4_en/bpb).""" + + label_map: tuple[tuple[str, str], ...] 
| None = None + """Optional mapping from experiment_name -> display label as tuple of pairs.""" + + save_plots: bool = True + """Whether to save HTML plots to output_path.""" + + upload_to_wandb: bool = True + """Whether to upload plots to WandB.""" + + wandb_entity: str = "marin-community" + """WandB entity for uploads.""" + + wandb_project: str = "marin-analysis" + """WandB project for uploads.""" + + wandb_run_name: str = "scaling-ladder-analysis" + """Name for the WandB run.""" + + +def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: + """Execute scaling ladder analysis (called by ExecutorStep).""" + # Read metrics from training runs + raw_df = read_metrics_dataframe(config) + + if raw_df.empty: + logger.warning("No eval metrics found in training runs") + return + + # Convert label_map tuple to dict if provided + label_map = dict(config.label_map) if config.label_map else None + + # Transform to isoflop analysis format + isoflop_df = transform_metrics_for_isoflop(raw_df, config.metric_key, label_map) + + if isoflop_df.empty: + logger.warning("No valid isoflop data after transformation") + return + + logger.info(f"Loaded {len(isoflop_df)} runs for scaling ladder analysis") + logger.info(f"Labels found: {isoflop_df['label'].unique().tolist()}") + logger.info(f"FLOP budgets: {sorted(isoflop_df['flops'].unique())}") + + # Fit scaling laws + minima_records, scaling_fits, fit_curves = fit_scaling_laws(isoflop_df) + + logger.info(f"Found {len(minima_records)} optimal configurations") + for label, (alpha, A) in scaling_fits.items(): + logger.info(f" {label}: N* = {A:.2e} * C^{alpha:.3f}") + + # Convert minima to CandidateConfigs + configs = [] + for rec in minima_records: + if rec["hidden_dim"] == 0: + continue + candidate = CandidateConfig( + hidden_size=rec["hidden_dim"], + intermediate_dim=rec["hidden_dim"] * MLP_RATIO, + num_layers=rec["num_layers"], + num_heads=max(1, rec["hidden_dim"] // 128), + num_kv_heads=max(1, rec["hidden_dim"] // 128), + batch_size=rec["batch_size"], + train_steps=int(rec["optimal_tokens"] / (rec["batch_size"] * SEQ_LEN)), + learning_rate=(0.33 * math.sqrt(rec["batch_size"])) / rec["hidden_dim"], + beta2=0.98 ** (rec["batch_size"] / 128), + tokens=rec["optimal_tokens"], + flops_budget=rec["flops"], + ) + configs.append(candidate) + + result = IsoFlopAnalysisResult( + configs=configs, + scaling_fits=scaling_fits, + isoflop_df=isoflop_df, + minima_records=minima_records, + fit_curves=fit_curves, + ) + + # Save outputs + fs, _, _ = fsspec.get_fs_token_paths(config.output_path) + fs.makedirs(config.output_path, exist_ok=True) + + # Save result JSON + result_path = os.path.join(config.output_path, "isoflop_analysis_result.json") + with fs.open(result_path, "w") as f: + json.dump(result.to_json_dict(), f, indent=2) + logger.info(f"Saved results to {result_path}") + + # Save plots if enabled + if config.save_plots: + from marin.scaling_laws.scaling_plots import ( + create_isoflop_plot, + create_scaling_plot, + save_plots, + ) + + fig_isoflop = create_isoflop_plot(isoflop_df, minima_records, fit_curves) + fig_scaling = create_scaling_plot(minima_records, scaling_fits) + save_plots(fig_isoflop, fig_scaling, config.output_path) + + # Upload to WandB if enabled + if config.upload_to_wandb: + from marin.scaling_laws.scaling_plots import upload_plots_to_wandb + + upload_plots_to_wandb( + fig_isoflop, + fig_scaling, + entity=config.wandb_entity, + project=config.wandb_project, + run_name=config.wandb_run_name, + ) + + +# ---------------- Primary Export: 
ExecutorStep Factory ---------------- + + +def isoflop_analysis_step( + name: str, + training_runs: Sequence[ExecutorStep | InputName], + metric_key: str = DEFAULT_METRIC_KEY, + label_map: dict[str, str] | None = None, + save_plots: bool = True, + upload_to_wandb: bool = True, + wandb_entity: str = "marin-community", + wandb_project: str = "marin-analysis", + wandb_run_name: str | None = None, +) -> ExecutorStep: + """Create an ExecutorStep for scaling ladder analysis. + + This is the primary interface for using scaling ladder analysis in a pipeline. + The step will: + 1. Wait for all training runs to complete + 2. Read eval metrics from the training runs + 3. Fit scaling laws to find compute-optimal configurations + 4. Save plots and results to the output path + + Args: + name: Name for this executor step + training_runs: Training run ExecutorSteps or InputNames to analyze + metric_key: Which metric to use for loss (default: eval/paloma/c4_en/bpb) + label_map: Optional mapping from experiment_name -> display label + save_plots: Whether to save HTML plots (default: True) + upload_to_wandb: Whether to upload plots to WandB (default: True) + wandb_entity: WandB entity for uploads + wandb_project: WandB project for uploads + wandb_run_name: Name for WandB run (defaults to step name) + + Returns: + ExecutorStep configured to run the analysis + + Example: + >>> from marin.scaling_laws import isoflop_analysis_step + >>> analysis = scaling_ladder_step( + ... name="my-scaling-analysis", + ... training_runs=my_training_steps, + ... ) + """ + run_paths = [output_path_of(run) if isinstance(run, ExecutorStep) else run for run in training_runs] + + config = IsoFlopAnalysisConfig( + training_runs=run_paths, + output_path=this_output_path(), + metric_key=metric_key, + label_map=tuple(label_map.items()) if label_map else None, + save_plots=save_plots, + upload_to_wandb=upload_to_wandb, + wandb_entity=wandb_entity, + wandb_project=wandb_project, + wandb_run_name=wandb_run_name or name, + ) + + return ExecutorStep( + name=name, + fn=_run_isoflop_analysis_step, + config=config, + description=f"Scaling ladder analysis for {len(training_runs)} training runs", + ) + + +# ---------------- Programmatic Interface ---------------- + + +def run_isoflop_analysis( + training_runs: Sequence[ExecutorStep] | Sequence[str], + metric_key: str = DEFAULT_METRIC_KEY, + label_map: dict[str, str] | None = None, +) -> IsoFlopAnalysisResult: + """Analyze isoflop training runs and return optimal training configurations. + + This is the programmatic interface for scaling ladder analysis. For pipeline + usage, prefer `isoflop_analysis_step()` which returns an ExecutorStep. 
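+    Unlike the ExecutorStep variant, this reads eval metrics immediately, so the
+    training runs must already have completed and written their metrics.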
+ + Args: + training_runs: List of ExecutorSteps or path strings to training runs + metric_key: Which metric to use for loss (default: eval/paloma/c4_en/bpb) + label_map: Optional mapping from experiment_name -> display label + + Returns: + IsoFlopAnalysisResult with configs, scaling_fits, and analysis data + """ + # Convert to paths + run_paths = [] + for run in training_runs: + if isinstance(run, ExecutorStep): + run_paths.append(output_path_of(run)) + else: + run_paths.append(run) + + # Read metrics + config = EvalMetricsAnalysisConfig( + training_runs=run_paths, + output_path="analysis/scaling_ladder", + ) + raw_df = read_metrics_dataframe(config) + + if raw_df.empty: + logger.warning("No eval metrics found") + return IsoFlopAnalysisResult( + configs=[], + scaling_fits={}, + isoflop_df=pd.DataFrame(), + minima_records=[], + fit_curves={}, + ) + + # Transform to isoflop format + isoflop_df = transform_metrics_for_isoflop(raw_df, metric_key, label_map) + + if isoflop_df.empty: + logger.warning("No valid isoflop data after transformation") + return IsoFlopAnalysisResult( + configs=[], + scaling_fits={}, + isoflop_df=pd.DataFrame(), + minima_records=[], + fit_curves={}, + ) + + logger.info(f"Transformed {len(isoflop_df)} runs for scaling ladder analysis") + + # Fit scaling laws and extract optima + minima_records, scaling_fits, fit_curves = fit_scaling_laws(isoflop_df) + + # Convert minima records to CandidateConfig objects + configs = [] + for rec in minima_records: + if rec["hidden_dim"] == 0: + continue + candidate = CandidateConfig( + hidden_size=rec["hidden_dim"], + intermediate_dim=rec["hidden_dim"] * MLP_RATIO, + num_layers=rec["num_layers"], + num_heads=max(1, rec["hidden_dim"] // 128), + num_kv_heads=max(1, rec["hidden_dim"] // 128), + batch_size=rec["batch_size"], + train_steps=int(rec["optimal_tokens"] / (rec["batch_size"] * SEQ_LEN)), + learning_rate=(0.33 * math.sqrt(rec["batch_size"])) / rec["hidden_dim"], + beta2=0.98 ** (rec["batch_size"] / 128), + tokens=rec["optimal_tokens"], + flops_budget=rec["flops"], + ) + configs.append(candidate) + + return IsoFlopAnalysisResult( + configs=configs, + scaling_fits=scaling_fits, + isoflop_df=isoflop_df, + minima_records=minima_records, + fit_curves=fit_curves, + ) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_plot.py b/lib/marin/src/marin/scaling_laws/isoflop_plot.py deleted file mode 100644 index 30ce4c56db..0000000000 --- a/lib/marin/src/marin/scaling_laws/isoflop_plot.py +++ /dev/null @@ -1,583 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import os -import re -from dataclasses import dataclass -from collections.abc import Sequence - -import fsspec -import jax.numpy as jnp -import pandas as pd -import plotly.graph_objects as go -import plotly.io as pio -from jaxopt import ScipyMinimize - -try: - import wandb - - WANDB_AVAILABLE = True -except ImportError: - WANDB_AVAILABLE = False - -from marin.execution.executor import ExecutorStep, InputName, output_path_of -from marin.scaling_laws.eval_metrics_reader import ( - EvalMetricsAnalysisConfig, - create_analysis_step, - extract_run_name_from_path, - read_metrics_dataframe, -) - -try: - from experiments.isoflop_sweep import MARIN_SCALING_SUITES - - ISOFLOP_SWEEP_AVAILABLE = True -except ImportError: - ISOFLOP_SWEEP_AVAILABLE = False - MARIN_SCALING_SUITES = {} - - -logger = logging.getLogger(__name__) - - -# ---------------- Theme ---------------- -pio.templates.default = "plotly_white" - -# ---------------- Constants ---------------- -PALETTE = [ - "#1877F2", - "#F0701A", - "#5A24C7", - "#E42C97", - "#00487C", - "#0EAC96", - "#AB76FF", - "#B50550", - "#0099E6", - "#22085F", - "#783301", -] -MARKERS = [ - "circle", - "square", - "cross", - "x", - "triangle-up", - "triangle-down", - "triangle-left", - "triangle-right", - "pentagon", - "hexagon", - "hexagon2", - "star", - "star-triangle-up", - "star-triangle-down", - "star-square", - "star-diamond", - "hourglass", - "bowtie", -] -DASHES = ["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"] -DEFAULT_METRIC_KEY = "eval/paloma/c4_en/bpb" -SEQ_LEN = 4096 - -_MIN_MARKER = dict(symbol="diamond", size=10, color="#000000") -_SCALE_MARKER = dict(symbol="circle", size=9, color=PALETTE[0]) -_SCALE_LINE = dict(dash="dot", width=2, color=PALETTE[0]) - -REQUIRED_TAGS = {"steps", "B", "FLOPs", "d", "L"} -CANON_LABELS = ["nemo", "comma", "dclm"] # canonical dataset names we detect in displayName - - -# ---------------- Helpers ---------------- - - -def _parse_isoflop_run_name(run_name: str) -> dict | None: - """Parse metadata from isoflop run name. - - Expected format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} - Optionally with a trailing - which is ignored. - E.g., 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt' - or 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt-a1b2c3' - - Returns dict with: flops, d, L, B, experiment_name or None if parsing fails. - """ - # Strip optional - suffix - run_name = re.sub(r"-[0-9a-fA-F]{6}$", "", run_name) - - pattern = r"isoflop-([0-9.e+]+)-d(\d+)-L(\d+)-B(\d+)-(.+)" - match = re.match(pattern, run_name) - if not match: - return None - - flops_str, d, L, B, exp_name = match.groups() - return { - "flops": float(flops_str), - "d": int(d), - "L": int(L), - "B": int(B), - "experiment_name": exp_name, - } - - -def _robust_quad_logx(x: jnp.ndarray, y: jnp.ndarray, delta: float = 1.0) -> jnp.ndarray: - L = jnp.log10(x) - - def huber(residual): - abs_r = jnp.abs(residual) - quad = 0.5 * residual**2 - linear = delta * (abs_r - 0.5 * delta) - return jnp.where(abs_r <= delta, quad, linear) - - def objective(params): - a, b, c = params - pred = a * L**2 + b * L + c - residuals = y - pred - return jnp.sum(huber(residuals)) - - opt = ScipyMinimize(fun=objective, method="BFGS", value_and_grad=False) - init = jnp.array(jnp.polyfit(L, y, 2)) if len(L) >= 3 else jnp.array([0.0, *jnp.polyfit(L, y, 1)]) - return opt.run(init_params=init).params - - -def iso_plot_with_minima_df(df: pd.DataFrame): - """ - Expects df columns: tokens, loss, flops, params, name, label. 
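Before the plotting details below, a minimal sketch of how the quadratic-in-log-tokens fit is turned into a compute-optimal token count. It uses plain numpy least squares in place of the Huber objective above, and the data points are invented:

    import numpy as np

    tokens = np.array([1e9, 3e9, 1e10, 3e10, 1e11])   # synthetic isoflop curve
    loss = np.array([1.20, 1.10, 1.06, 1.09, 1.18])

    L = np.log10(tokens)
    a, b, c = np.polyfit(L, loss, 2)     # least-squares stand-in for _robust_quad_logx
    L_opt = -b / (2 * a)                 # vertex of the fitted parabola
    N_star = 10 ** L_opt                 # compute-optimal tokens for this budget
    loss_opt = a * L_opt**2 + b * L_opt + c
    print(N_star, loss_opt)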
- ISO plot: - - points: color by compute bucket (FLOPs), marker shape by dataset label - - dashed parabolas: per-(label, FLOPs) robust quadratic fits (restored) - - minima per (label, FLOPs): black diamonds - SCALING plot: - - one N* ~ A*C^alpha fit line per dataset (distinct color/dash) - - dataset minima as points in matching color - """ - if df is None or df.empty: - return go.Figure(), go.Figure() - - present = list(dict.fromkeys(df["label"].tolist())) - datasets = [lab for lab in CANON_LABELS if lab in present] + [lab for lab in present if lab not in CANON_LABELS] - - # Visual maps - buckets = sorted(df.flops.unique()) - bucket_color = {C: PALETTE[i % len(PALETTE)] for i, C in enumerate(buckets)} # ISO: color = compute bucket - ds_marker = {lab: MARKERS[i % len(MARKERS)] for i, lab in enumerate(datasets)} # ISO: shape = dataset - DS_COLORS = PALETTE - DASHES = ["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"] - - fig_iso = go.Figure() - minima = [] # (label, C, N_star, loss) - - # ---- ISO: scatter, per-(label,C) parabola (RESTORED), and minima - for lab in datasets: - for C in buckets: - sub = df[(df.flops == C) & (df.label == lab)].sort_values("tokens") - if sub.empty: - continue - - # scatter - fig_iso.add_trace( - go.Scatter( - x=sub.tokens, - y=sub.loss, - mode="markers", - marker=dict(symbol=ds_marker[lab], color=bucket_color[C], size=8), - name=f"{lab}, {C:.2e} FLOPs", - legendgroup=f"{lab}, {C:.2e}", - hovertemplate=( - "C=%{text:.2e} FLOPs
<br>tokens=%{x:.3e}<br>" - "loss=%{y:.4f}<br>
params=%{customdata:.3e}" - ), - text=[C] * len(sub), - customdata=sub.params.values, - ) - ) - - # robust quadratic fit in log10(tokens) - a, b, c = _robust_quad_logx(jnp.array(sub.tokens.values), jnp.array(sub.loss.values)) - if a == 0: - continue - - # draw the parabola for this (lab, C) - Ls = jnp.linspace(jnp.log10(sub.tokens.min()), jnp.log10(sub.tokens.max()), 200) - fig_iso.add_trace( - go.Scatter( - x=10**Ls, - y=a * Ls**2 + b * Ls + c, - mode="lines", - line=dict(color=bucket_color[C], dash="dash", width=2), - showlegend=False, # avoid legend clutter - legendgroup=f"{lab}, {C:.2e}", - ) - ) - - # compute and draw minimum - L_opt = -b / (2 * a) - N_star = float(10**L_opt) - loss_opt = float(a * L_opt**2 + b * L_opt + c) - params_opt = sub.iloc[(sub.tokens - N_star).abs().argmin()].params - minima.append((lab, float(C), N_star, loss_opt)) - - fig_iso.add_trace( - go.Scatter( - x=[N_star], - y=[loss_opt], - mode="markers", - marker=_MIN_MARKER, - showlegend=False, - legendgroup=f"{lab}, {C:.2e}", - hovertemplate=( - "Compute-optimal
" - "C=%{text:.2e} FLOPs
tokens=%{x:.3e}
" - "loss=%{y:.4f}
params=%{customdata:.3e}" - ), - text=[C], - customdata=[params_opt], - ) - ) - - fig_iso.update_layout( - template="plotly_white", - xaxis_type="log", - xaxis_title="Tokens (log scale)", - yaxis_title="Bits Per Byte Validation", - title="Marin IsoFLOP Suite", - width=1000, - height=600, - ) - - # ---- SCALING: separate line per dataset - if not minima: - return fig_iso, go.Figure() - - fig_scale = go.Figure() - by_lab = {} - for lab, C, N_star, _ in minima: - by_lab.setdefault(lab, []).append((C, N_star)) - - for i, lab in enumerate(datasets): - pts = by_lab.get(lab, []) - if not pts: - continue - pts = sorted(pts) - Cs, Ns = zip(*pts, strict=False) - Cs = jnp.array(Cs) - Ns = jnp.array(Ns) - - color = DS_COLORS[i % len(DS_COLORS)] - dash = DASHES[i % len(DASHES)] - - # plot minima points - fig_scale.add_trace( - go.Scatter( - x=list(map(float, Cs)), - y=list(map(float, Ns)), - mode="markers", - marker=dict(symbol=_SCALE_MARKER["symbol"], size=_SCALE_MARKER["size"], color=color), - name=f"{lab} minima", - legendgroup=lab, - ) - ) - - if len(Cs) >= 2: - alpha, logA = jnp.polyfit(jnp.log10(Cs), jnp.log10(Ns), 1) - A = 10**logA - Cmin, Cmax = float(Cs.min()), float(Cs.max()) - C_fit = jnp.logspace(jnp.log10(Cmin) - 0.1, jnp.log10(Cmax) + 0.1, 400) - N_fit = A * (C_fit**alpha) - - fig_scale.add_trace( - go.Scatter( - x=list(map(float, C_fit)), - y=list(map(float, N_fit)), - mode="lines", - line=dict(color=color, dash=dash, width=_SCALE_LINE["width"]), - name=f"{lab} fit", - legendgroup=lab, - ) - ) - - fig_scale.update_layout( - template="plotly_white", - xaxis_type="log", - yaxis_type="log", - xaxis_title="Compute budget C (FLOPs, log)", - yaxis_title="Optimal tokens N* (log)", - title="Scaling fits per dataset", - ) - - return fig_iso, fig_scale - - -# ---------------- Executor Integration ---------------- -@dataclass(frozen=True) -class IsoFlopAnalysisConfig(EvalMetricsAnalysisConfig): - """Config for isoflop analysis - extends base eval metrics analysis. - - Inherits training_runs, output_path, and backfill settings from base. - Adds isoflop-specific parameters. - """ - - metric_key: str = DEFAULT_METRIC_KEY - """Which metric to use for loss (default: eval/paloma/c4_en/bpb).""" - - label_map: dict[str, str] | None = None - """Optional mapping from experiment_name -> display label.""" - - upload_to_wandb: bool = True - """Whether to upload plots to WandB.""" - - wandb_entity: str = "marin-community" - wandb_project: str = "marin-analysis" - wandb_run_name: str = "isoflop-analysis" - - -def _transform_metrics_for_isoflop( - df: pd.DataFrame, - metric_key: str, - label_map: dict[str, str] | None, -) -> pd.DataFrame: - """Transform raw metrics DataFrame into isoflop plotting format. 
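The scaling plot earlier in this hunk fits the per-dataset minima as a power law in compute, N* ~ A * C**alpha, via a straight line in log-log space. A minimal standalone version of that fit with invented minima:

    import numpy as np

    Cs = np.array([1e18, 3e18, 1e19, 3e19])        # compute budgets (FLOPs)
    Ns = np.array([4.0e9, 7.5e9, 1.5e10, 2.8e10])  # fitted optimal tokens per budget

    alpha, logA = np.polyfit(np.log10(Cs), np.log10(Ns), 1)  # line in log-log space
    A = 10 ** logA
    N_fit = A * Cs ** alpha                                  # N* ~ A * C**alpha
    print(alpha, A)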
- - Takes the generic metrics DataFrame from read_metrics_dataframe() and - transforms it into the format expected by iso_plot_with_minima_df(): - columns: tokens, loss, flops, params, name, label - """ - if df.empty: - return pd.DataFrame(columns=["tokens", "loss", "flops", "params", "name", "label"]) - - # Get final metrics for each run (max step) - final_metrics = df.loc[df.groupby("run_path")["step"].idxmax()].copy() - - records = [] - for _, row in final_metrics.iterrows(): - run_path = row["run_path"] - run_name = extract_run_name_from_path(run_path) - - # Parse metadata from run name - meta = _parse_isoflop_run_name(run_name) - if meta is None: - print(f"Warning: Could not parse metadata from run name: {run_name}") - continue - - flops = meta["flops"] - if flops < 1e18: - continue - - # Calculate tokens = steps * batch * seq_len - steps = row["step"] - batch = meta["B"] - tokens = steps * batch * SEQ_LEN - - # Get loss from the metric column - loss = row.get(metric_key) - if loss is None or pd.isna(loss): - print(f"Warning: Missing metric {metric_key} for run {run_name}") - continue - - params = row.get("parameter_count") - if params is None or pd.isna(params): - params = None - - # Determine label - exp_name = meta["experiment_name"] - if label_map and exp_name in label_map: - label = label_map[exp_name] - else: - label = exp_name - for canon in CANON_LABELS: - if canon in exp_name.lower(): - label = canon - break - - records.append( - dict( - tokens=tokens, - loss=loss, - flops=flops, - params=params, - name=run_name, - label=label, - ) - ) - - return pd.DataFrame.from_records(records) - - -def run_isoflop_analysis(config: IsoFlopAnalysisConfig) -> None: - """Run isoflop analysis from training runs. - - This is a subtype of eval metrics analysis that: - 1. Reads metrics using the base read_metrics_dataframe() - 2. Transforms them for isoflop plotting - 3. 
Generates and saves isoflop/scaling plots - """ - # Use inherited metrics reading from base - raw_df = read_metrics_dataframe(config) - - if raw_df.empty: - print("Warning: No eval metrics found") - return - - # Transform to isoflop format - df = _transform_metrics_for_isoflop(raw_df, config.metric_key, config.label_map) - - if df.empty: - print("Warning: No valid isoflop data after transformation") - return - - print(f"Transformed {len(df)} runs for isoflop analysis") - fig_iso, fig_scaling = iso_plot_with_minima_df(df) - - # Save plots locally - fs, _, _ = fsspec.get_fs_token_paths(config.output_path) - fs.makedirs(config.output_path, exist_ok=True) - - iso_path = os.path.join(config.output_path, "isoflop_plot.html") - scaling_path = os.path.join(config.output_path, "scaling_plot.html") - - with fs.open(iso_path, "w") as f: - f.write(fig_iso.to_html()) - print(f"Wrote isoflop plot to {iso_path}") - - with fs.open(scaling_path, "w") as f: - f.write(fig_scaling.to_html()) - print(f"Wrote scaling plot to {scaling_path}") - - # Optionally upload to WandB - if config.upload_to_wandb and WANDB_AVAILABLE: - wandb.login() - run = wandb.init( - entity=config.wandb_entity, - project=config.wandb_project, - job_type="isoflop-analysis", - name=config.wandb_run_name, - resume="allow", - ) - wandb.log( - { - "isoFLOP_plot": wandb.Plotly(fig_iso), - "scaling_plot": wandb.Plotly(fig_scaling), - } - ) - run.finish() - print("Uploaded plots to WandB") - - -def create_isoflop_analysis_step( - name: str, - training_runs: Sequence[ExecutorStep | InputName], - metric_key: str = DEFAULT_METRIC_KEY, - label_map: dict[str, str] | None = None, - upload_to_wandb: bool = True, - description: str | None = None, -) -> ExecutorStep: - """Create an ExecutorStep for isoflop analysis. - - This uses the base create_analysis_step() with IsoFlopAnalysisConfig. - - Args: - name: Name for this executor step - training_runs: Training run ExecutorSteps (creates blocking dependencies) - metric_key: Which metric to use for loss - label_map: Optional mapping from experiment_name -> display label - upload_to_wandb: Whether to upload plots to WandB - description: Optional description - - Returns: - ExecutorStep configured to run isoflop analysis - """ - return create_analysis_step( - name=name, - training_runs=training_runs, - analysis_fn=run_isoflop_analysis, - config_class=IsoFlopAnalysisConfig, - description=description or f"IsoFLOP analysis for {len(training_runs)} runs", - metric_key=metric_key, - label_map=label_map, - upload_to_wandb=upload_to_wandb, - ) - - -# ---------------- Main (using experiments/isoflop_sweep.py) ---------------- -def main_from_isoflop_sweep( - suite_names: list[str] | None = None, - metric_key: str = DEFAULT_METRIC_KEY, - upload_to_wandb: bool = True, -): - """ - Run isoflop analysis using training runs from experiments/isoflop_sweep.py. - - Args: - suite_names: Names of scaling suites from MARIN_SCALING_SUITES (default: all) - metric_key: Which metric to use for loss - upload_to_wandb: Whether to upload plots to WandB - """ - if not ISOFLOP_SWEEP_AVAILABLE: - raise RuntimeError( - "Cannot import from experiments.isoflop_sweep. " "Make sure the experiments module is in your Python path." 
- ) - - if suite_names is None: - suite_names = list(MARIN_SCALING_SUITES.keys()) - - # Collect all training runs from the specified suites - all_training_runs = [] - label_map = {} - - for suite_name in suite_names: - if suite_name not in MARIN_SCALING_SUITES: - logger.warning(f"Suite '{suite_name}' not found in MARIN_SCALING_SUITES") - continue - - steps, _ = MARIN_SCALING_SUITES[suite_name] - # Filter to just training steps (not eval steps) - training_steps = [step for step in steps if step.name.startswith("isoflop-")] - all_training_runs.extend(training_steps) - - # Build label map from experiment names - for step in training_steps: - meta = _parse_isoflop_run_name(step.name) - if meta: - exp_name = meta["experiment_name"] - # Map experiment name to canonical label - for canon in CANON_LABELS: - if canon in exp_name.lower(): - label_map[exp_name] = canon - break - - if not all_training_runs: - logger.error("No training runs found in specified suites") - return - - logger.info(f"Found {len(all_training_runs)} training runs across {len(suite_names)} suites") - - # Create and run analysis - config = IsoFlopAnalysisConfig( - training_runs=[output_path_of(step) for step in all_training_runs], - output_path="analysis/isoflop", - metric_key=metric_key, - label_map=label_map, - upload_to_wandb=upload_to_wandb, - ) - - run_isoflop_analysis(config) - - -if __name__ == "__main__": - # Use the new logic that imports from isoflop_sweep.py - main_from_isoflop_sweep( - suite_names=["nemotron", "common_pile", "dclm-default"], - upload_to_wandb=True, - ) diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py new file mode 100644 index 0000000000..9a0a669427 --- /dev/null +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -0,0 +1,411 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Scaling ladder: compute-optimal training runs based on IsoFLOP analysis. + +This module provides ExecutorSteps for training models with compute-optimal +configurations derived from IsoFLOP analysis. 
+ +Usage: + from marin.scaling_laws import isoflop_analysis_step + from marin.scaling_laws.scaling_ladder import scaling_ladder_rung_step + + # First, run IsoFLOP analysis + analysis = isoflop_analysis_step( + name="scaling-analysis", + training_runs=isoflop_training_steps, + ) + + # Then create optimal training steps (ladder rungs) that depend on the analysis + rung_1e21 = scaling_ladder_rung_step( + name="optimal-1e21", + analysis_step=analysis, + target_budget=1e21, + label="nemo", + dataset=my_tokenized_dataset, + ) +""" + +import json +import logging +import os +from collections.abc import Sequence +from dataclasses import dataclass + +import fsspec +from levanter.data.text import LMMixtureDatasetConfig +from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig +from levanter.models.qwen import Qwen3Config + +from marin.execution.executor import ExecutorStep, InputName, output_path_of, this_output_path +from marin.scaling_laws.isoflop_analysis import ( + CandidateConfig, + IsoFlopSweepConfig, + predict_optimal_config, +) + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ScalingLadderRungConfig: + """Configuration for one rung of the scaling ladder (one compute-optimal training run). + + This config references an IsoFLOP analysis step and specifies + the target compute budget. At runtime, the optimal config is loaded + from the analysis output. + """ + + analysis_output_path: str + """Path to the IsoFLOP analysis output directory.""" + + target_budget: float + """Target compute budget in FLOPs.""" + + label: str + """Dataset label to use for scaling fit (e.g., 'nemo', 'comma', 'dclm').""" + + dataset: str + """Path to tokenized dataset for training.""" + + output_path: str + """Where to write training outputs.""" + + tokenizer: str = "stanford-crfm/marin-tokenizer" + """Tokenizer to use.""" + + seq_len: int = 4096 + """Sequence length for training.""" + + sweep_config: IsoFlopSweepConfig | None = None + """Optional sweep config for predict_optimal_config. Uses defaults if None.""" + + +def load_scaling_fits(analysis_path: str) -> dict[str, tuple[float, float]]: + """Load scaling fits from an IsoFLOP analysis output.""" + result_path = os.path.join(analysis_path, "isoflop_analysis_result.json") + fs, _, _ = fsspec.get_fs_token_paths(result_path) + + with fs.open(result_path, "r") as f: + result = json.load(f) + + # Convert lists back to tuples + return {k: tuple(v) for k, v in result["scaling_fits"].items()} + + +def get_optimal_candidate(config: ScalingLadderRungConfig) -> CandidateConfig: + """Load scaling fits and predict optimal config for target budget.""" + from experiments.metrics.wandb_related import get_vocab_size_for_tokenizer + + scaling_fits = load_scaling_fits(config.analysis_output_path) + vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) + + candidate = predict_optimal_config( + scaling_fits=scaling_fits, + target_flops=config.target_budget, + label=config.label, + sweep_config=config.sweep_config, + vocab_size=vocab_size, + ) + + if candidate is None: + raise RuntimeError( + f"Could not find optimal config for budget {config.target_budget:.2e} " + f"and label '{config.label}'" + ) + + return candidate + + +def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: + """Run one rung of the scaling ladder (one compute-optimal training run). + + This function: + 1. Loads scaling fits from the analysis output + 2. Predicts the optimal config for the target budget + 3. 
Trains a model with that config using the same infrastructure as default_train + """ + from datetime import timedelta + + import jmp + from fray.cluster import ResourceConfig + from levanter.checkpoint import CheckpointerConfig + from levanter.main.train_lm import TrainLmConfig + from levanter.optim.cautious import CautiousConfig + from levanter.tracker.wandb import WandbConfig + from levanter.trainer import TrainerConfig + + from experiments.defaults import _prepare_data_config + from experiments.llama import compute_num_parameters + from experiments.metrics.wandb_related import get_vocab_size_for_tokenizer + from marin.processing.tokenize import lm_mixture_data_config + from marin.scaling_laws.isoflop_analysis import pick_v5p_type + from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm + + # Get the optimal candidate config from analysis + candidate = get_optimal_candidate(config) + vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) + + logger.info( + f"Training with optimal config for {config.target_budget:.2e} FLOPs:\n" + f" hidden_size={candidate.hidden_size}, num_layers={candidate.num_layers}\n" + f" batch_size={candidate.batch_size}, train_steps={candidate.train_steps}\n" + f" learning_rate={candidate.learning_rate:.6f}, tokens={candidate.tokens:.2e}" + ) + + # Build model config + model_cfg = Qwen3Config( + max_seq_len=config.seq_len, + hidden_dim=candidate.hidden_size, + intermediate_dim=candidate.intermediate_dim, + num_heads=candidate.num_heads, + num_kv_heads=candidate.num_kv_heads, + num_layers=candidate.num_layers, + rope=Llama3RotaryEmbeddingsConfig(), + ) + + # Pick TPU type based on memory requirements + param_count = compute_num_parameters(model_cfg, vocab_size) + tpu_type = pick_v5p_type( + param_count, + candidate.hidden_size, + candidate.num_layers, + candidate.batch_size, + config.seq_len, + vocab_size, + ) + + # Build optimizer config (matches isoflop_sweep defaults) + optimizer_cfg = CautiousConfig( + learning_rate=candidate.learning_rate, + weight_decay=0.1, + min_lr_ratio=0.0, + warmup=0.1, + beta1=0.95, + beta2=candidate.beta2, + epsilon=1e-15, + max_grad_norm=1, + adamc_weight_decay=True, + lr_schedule="linear", + decay=0.2, + ) + + # Prepare data config (uses same helper as default_train) + data_config = lm_mixture_data_config( + components={"train": config.dataset}, + weights={"train": 1.0}, + ) + pretraining_data = _prepare_data_config(data_config, use_default_validation=True) + + # Build TrainLmConfig (mirrors default_train structure) + train_config = TrainLmConfig( + data=pretraining_data, + trainer=TrainerConfig( + tracker=WandbConfig( + project="marin", + tags=[ + "optimal-training", + f"FLOPs={config.target_budget:.1e}", + f"label={config.label}", + ], + ), + mp=jmp.get_policy("p=f32,c=bfloat16"), + train_batch_size=candidate.batch_size, + num_train_steps=candidate.train_steps, + steps_per_eval=1000, + checkpointer=CheckpointerConfig( + save_interval=timedelta(minutes=10), + keep=[dict(every=5000)], + ), + replica_dcn_axis_size=-1, + allow_nondivisible_batch_size=True, + ), + train_seq_len=config.seq_len, + model=model_cfg, + optimizer=optimizer_cfg, + ) + + # Build pod config and run training + full_config = TrainLmOnPodConfig( + train_config=train_config, + resources=ResourceConfig.with_tpu(tpu_type), + output_path=config.output_path, + ) + + run_levanter_train_lm(full_config) + + +def scaling_ladder_rung_step( + name: str, + analysis_step: ExecutorStep, + target_budget: float, + label: str, + dataset: InputName | 
ExecutorStep | LMMixtureDatasetConfig, + tokenizer: str = "stanford-crfm/marin-tokenizer", + seq_len: int = 4096, +) -> ExecutorStep: + """Create an ExecutorStep for one rung of the scaling ladder. + + This step depends on an IsoFLOP analysis step and will train a model + using the optimal configuration predicted from the scaling fits. + + Args: + name: Name for this executor step + analysis_step: The IsoFLOP analysis step to read fits from + target_budget: Target compute budget in FLOPs + label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') + dataset: Tokenized dataset to train on + tokenizer: Tokenizer to use + seq_len: Sequence length for training + + Returns: + ExecutorStep configured to run one optimal training run + """ + # Resolve dataset path + if isinstance(dataset, ExecutorStep): + dataset_path = output_path_of(dataset) + elif isinstance(dataset, LMMixtureDatasetConfig): + # For mixture configs, we'll need to handle this differently + # For now, just serialize it somehow - this is a limitation + raise NotImplementedError("LMMixtureDatasetConfig not yet supported for scaling_ladder_rung_step") + else: + dataset_path = dataset + + config = ScalingLadderRungConfig( + analysis_output_path=output_path_of(analysis_step), + target_budget=target_budget, + label=label, + dataset=dataset_path, + output_path=this_output_path(), + tokenizer=tokenizer, + seq_len=seq_len, + ) + + return ExecutorStep( + name=os.path.join("checkpoints", name), + fn=run_scaling_ladder_rung, + config=config, + description=f"Scaling ladder rung: optimal training for {target_budget:.1e} FLOPs based on IsoFLOP analysis", + pip_dependency_groups=["tokenize_train"], + ) + + +# ---------------- Scaling Ladder Suite ---------------- + + +@dataclass +class ScalingLadderSuite: + """A suite containing IsoFLOP analysis and scaling ladder rungs (optimal training steps). + + This is returned by `scaling_ladder_suite()` and contains all the steps + needed for end-to-end scaling ladder: IsoFLOP analysis + optimal training runs. + """ + + analysis: ExecutorStep + """The IsoFLOP analysis step.""" + + optimal_runs: list[ExecutorStep] + """Scaling ladder rungs: training steps for each target budget, using predicted optimal configs.""" + + @property + def all_steps(self) -> list[ExecutorStep]: + """All steps in the suite (analysis + optimal runs).""" + return [self.analysis, *self.optimal_runs] + + +def scaling_ladder_suite( + name: str, + training_runs: Sequence[ExecutorStep | InputName], + target_budgets: Sequence[float], + label: str, + dataset: InputName | ExecutorStep, + tokenizer: str = "stanford-crfm/marin-tokenizer", + seq_len: int = 4096, + metric_key: str = "eval/paloma/c4_en/bpb", + label_map: dict[str, str] | None = None, + save_plots: bool = True, + upload_to_wandb: bool = True, + wandb_entity: str = "marin-community", + wandb_project: str = "marin-analysis", +) -> ScalingLadderSuite: + """Create a complete scaling ladder: IsoFLOP analysis + optimal training runs. + + This is the full pipeline interface that creates: + 1. An IsoFLOP analysis step that fits scaling laws + 2. Scaling ladder rungs (optimal training steps) for each target budget + + The optimal training steps depend on the analysis step and will train + models using compute-optimal configurations predicted from the scaling fits. 
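For orientation, the artifact read by `load_scaling_fits` earlier in this file stores one `(alpha, A)` pair per dataset label under a `scaling_fits` key; the concrete numbers below are invented. Given such a pair, the optimal token count for a new budget follows the same power law used in the plots:

    # isoflop_analysis_result.json (illustrative contents only):
    #   {"scaling_fits": {"nemo": [0.51, 0.23], "dclm": [0.49, 0.31]}}
    alpha, A = 0.51, 0.23                         # invented fit for the "nemo" label
    target_budget = 1e21                          # FLOPs
    optimal_tokens = A * target_budget ** alpha   # N* ~ A * C**alpha
    print(f"{optimal_tokens:.3e}")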
+ + Args: + name: Base name for the steps + training_runs: IsoFLOP training run ExecutorSteps to analyze + target_budgets: Target compute budgets (in FLOPs) for optimal training + label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') + dataset: Tokenized dataset for optimal training runs + tokenizer: Tokenizer to use + seq_len: Sequence length for training + metric_key: Which metric to use for loss + label_map: Optional mapping from experiment_name -> display label + save_plots: Whether to save HTML plots + upload_to_wandb: Whether to upload plots to WandB + wandb_entity: WandB entity for uploads + wandb_project: WandB project for uploads + + Returns: + ScalingLadderSuite containing the analysis step and optimal training steps + + Example: + >>> suite = scaling_ladder_suite( + ... name="nemo-scaling", + ... training_runs=isoflop_training_steps, + ... target_budgets=[1e21, 3e21, 1e22], + ... label="nemo", + ... dataset=nemotron_tokenized, + ... ) + >>> all_steps = [*isoflop_training_steps, *suite.all_steps] + """ + from marin.scaling_laws.isoflop_analysis import isoflop_analysis_step + + # Create the IsoFLOP analysis step + analysis = isoflop_analysis_step( + name=f"{name}-analysis", + training_runs=training_runs, + metric_key=metric_key, + label_map=label_map, + save_plots=save_plots, + upload_to_wandb=upload_to_wandb, + wandb_entity=wandb_entity, + wandb_project=wandb_project, + wandb_run_name=f"{name}-analysis", + ) + + # Create scaling ladder rungs (optimal training steps) for each target budget + optimal_runs = [] + for budget in target_budgets: + run_step = scaling_ladder_rung_step( + name=f"{name}-optimal-{budget:.0e}", + analysis_step=analysis, + target_budget=budget, + label=label, + dataset=dataset, + tokenizer=tokenizer, + seq_len=seq_len, + ) + optimal_runs.append(run_step) + + return ScalingLadderSuite( + analysis=analysis, + optimal_runs=optimal_runs, + ) diff --git a/lib/marin/src/marin/scaling_laws/scaling_laws.py b/lib/marin/src/marin/scaling_laws/scaling_laws.py deleted file mode 100644 index 7d602ac580..0000000000 --- a/lib/marin/src/marin/scaling_laws/scaling_laws.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file contains functions for setting scaling law configurations, wrapper functions to call the relevant -regressions/predictions, and creating a WandB report with the results. The code here implements the function -(see `run_scaling_law_analysis`, that will be called by an ExecutorStep in the scaling laws analysis pipeline. - -Our objective is to predict the accuracy of a larger target model on a specific benchmark. -This prediction is done through a two-step modeling process using (N, D) data from various smaller models: -- we first fit a power-law model to predict the task loss from the number of parameters and tokens. -- then, we fit a sigmoidal model to predict task accuracy from the task loss. 
- -Reference: - Establishing Task Scaling Laws via Compute-Efficient Model Ladders - Bhagia et. al 2024 - https://arxiv.org/pdf/2412.04403. -""" - -from collections.abc import Sequence -from dataclasses import dataclass, field - -import numpy as np -import wandb - -from marin.execution.executor import ExecutorStep -from marin.scaling_laws.utils import ( - ProjectionPoint, - get_default_projection_points, - plot_actual_vs_predicted, - plot_scaling_projections, -) -from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT - - -@dataclass(frozen=True) -class ScalingLawConfig: - name: str - """name of the scaling law analysis or config (used for the report name)""" - - ladder_model_steps: Sequence[ExecutorStep | str] - """list of (smaller model) steps or wandb run ids to be used as input for scaling laws""" - - pred_model_step: ExecutorStep | str - """executor step or wandb run id for the larger model to make predictions for""" - - projection_points: list[ProjectionPoint] | None = None - """Points to project to, consisting of number of parameters and tokens""" - - task_losses: Sequence[str] = field(default_factory=lambda: ["eval/paloma/c4_en/bpb"]) - """task losses to predict for scaling laws (eg. c4en bpb)""" - - task_accuracies: Sequence[str] | None = None - """task accuracy to predict for the larger model (eg. hellaswag accuracy)""" - - use_log_for_ND: bool = True - """whether to use log space for N,D in scaling laws""" - - normalize_ND: bool = True - """whether to normalize N,D in scaling laws""" - - count_embedding_params: bool = False - """whether to count embedding parameters in scaling laws""" - - entity: str = WANDB_ENTITY - project: str = "marin" - - def __post_init__(self): - # Set default projection points if none provided - if self.projection_points is None: - object.__setattr__( - self, - "projection_points", - get_default_projection_points(count_embedding_params=self.count_embedding_params), - ) - - -def get_wandb_run_id_from_step(step: ExecutorStep) -> str: - """ - Get the wandb run id from a given ExecutorStep. - """ - return step.config.trainer.tracker.id - - -def run_scaling_law_analysis(config: ScalingLawConfig) -> None: - """ - Analyze scaling laws for a given task loss and multiple accuracy metrics. 
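For context on the module being removed here, a sketch of how its entry point was typically wired up; the run IDs and the accuracy metric name are placeholders:

    # Module removed by this patch; shown only to document the old interface.
    from marin.scaling_laws.scaling_laws import ScalingLawConfig, run_scaling_law_analysis

    config = ScalingLawConfig(
        name="ladder-demo",
        ladder_model_steps=["wandb_run_small_1", "wandb_run_small_2", "wandb_run_small_3"],
        pred_model_step="wandb_run_large",
        task_losses=["eval/paloma/c4_en/bpb"],
        task_accuracies=["placeholder/hellaswag/accuracy"],  # placeholder metric name
    )
    run_scaling_law_analysis(config)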
- """ - from marin.scaling_laws.utils import fit_scaling_laws - - input_run_ids = [ - get_wandb_run_id_from_step(step) if isinstance(step, ExecutorStep) else step - for step in config.ladder_model_steps - ] - - pred_run_id = None - if config.pred_model_step: - pred_run_id = ( - get_wandb_run_id_from_step(config.pred_model_step) - if isinstance(config.pred_model_step, ExecutorStep) - else config.pred_model_step - ) - - projections, predictions = fit_scaling_laws( - runs=input_run_ids, - loss_metrics=config.task_losses, - accuracy_metrics=config.task_accuracies, - entity=config.entity, - project=config.project, - pred_run=pred_run_id, - projection_points=config.projection_points, - count_embedding_params=config.count_embedding_params, - use_log_for_ND=config.use_log_for_ND, - normalize_ND=config.normalize_ND, - ) - - log_and_create_report( - projections=projections, - points=config.projection_points, - predictions=predictions, - input_run_ids=input_run_ids, - pred_run_id=pred_run_id, - scaling_law_config=config, - ) - - -def log_and_create_report( - projections: dict[str, np.ndarray], - points: list[ProjectionPoint] | None, - predictions: tuple[dict, dict, np.ndarray, np.ndarray] | None, - input_run_ids: list, - pred_run_id: str | None, - scaling_law_config: ScalingLawConfig, - wandb_project: str = "marin-scaling-laws", - wandb_entity: str = WANDB_ENTITY, - wandb_source_project: str = WANDB_PROJECT, -): - """ - Logs scaling law analysis creates a concise WandB report with plots and info about runs. - """ - # Initialize WandB run - run = wandb.init( - project=wandb_project, - entity=wandb_entity, - name=f"""Scaling Law Report: {pred_run_id if pred_run_id else 'projection'}-{scaling_law_config.name}""", - tags=["scaling_laws"], - config={ - "input_runs": input_run_ids, - "prediction_run": pred_run_id, - }, - reinit=True, - ) - - plots = {} - - # Log projections - if points: - for loss_name, projection in projections.items(): - figure = plot_scaling_projections(projection, points) - plots[f"Projection - {loss_name}"] = wandb.Image(figure) - - # Log predictions if available - if predictions: - loss_results, accuracy_results, loss_tokens, acc_tokens = predictions - - if loss_results: - for loss_name, (actual_loss, predicted_loss) in loss_results.items(): - figure = plot_actual_vs_predicted( - actual_loss.tolist(), - predicted_loss.tolist(), - title=f"Actual vs Predicted {loss_name}", - task_metric=loss_name, - tokens=loss_tokens, - ) - plots[f"Task Loss - {loss_name}"] = wandb.Image(figure) - - if accuracy_results: - for metric, (actual_acc, predicted_acc) in accuracy_results.items(): - figure = plot_actual_vs_predicted( - actual_acc.tolist(), - predicted_acc.tolist(), - title=f"Actual vs Predicted {metric}", - task_metric=metric, - tokens=acc_tokens, - ) - plots[f"Task Accuracy - {metric}"] = wandb.Image(figure) - - # Log all plots - wandb.log(plots) - - # Info about runs and links - input_run_links = [ - f"https://wandb.ai/{wandb_entity}/{wandb_source_project}/runs/{run_id}" for run_id in input_run_ids - ] - prediction_run_link = ( - f"https://wandb.ai/{wandb_entity}/{wandb_source_project}/runs/{pred_run_id}" if pred_run_id else None - ) - run.summary.update( - { - "Input Runs": input_run_links, - "Prediction Run": prediction_run_link, - "Task Losses": scaling_law_config.task_losses, - "Task Accuracies": scaling_law_config.task_accuracies, - } - ) - - wandb.finish() diff --git a/lib/marin/src/marin/scaling_laws/scaling_plots.py b/lib/marin/src/marin/scaling_laws/scaling_plots.py new file mode 
100644 index 0000000000..bf49f312ae --- /dev/null +++ b/lib/marin/src/marin/scaling_laws/scaling_plots.py @@ -0,0 +1,336 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Visualization functions for scaling ladder analysis. + +This module provides plotting utilities for isoflop analysis results. +All plotly-related code is contained here to keep the core scaling_ladder +module free of visualization dependencies. +""" + +import logging +import os + +import fsspec +import jax.numpy as jnp +import pandas as pd +import plotly.graph_objects as go +import plotly.io as pio + +try: + import wandb + + WANDB_AVAILABLE = True +except ImportError: + WANDB_AVAILABLE = False + + +logger = logging.getLogger(__name__) + +# ---------------- Theme ---------------- +pio.templates.default = "plotly_white" + +# ---------------- Visual Constants ---------------- +PALETTE = [ + "#1877F2", + "#F0701A", + "#5A24C7", + "#E42C97", + "#00487C", + "#0EAC96", + "#AB76FF", + "#B50550", + "#0099E6", + "#22085F", + "#783301", +] + +MARKERS = [ + "circle", + "square", + "cross", + "x", + "triangle-up", + "triangle-down", + "triangle-left", + "triangle-right", + "pentagon", + "hexagon", + "hexagon2", + "star", + "star-triangle-up", + "star-triangle-down", + "star-square", + "star-diamond", + "hourglass", + "bowtie", +] + +DASHES = ["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"] + +_MIN_MARKER = dict(symbol="diamond", size=10, color="#000000") +_SCALE_MARKER = dict(symbol="circle", size=9, color=PALETTE[0]) +_SCALE_LINE = dict(dash="dot", width=2, color=PALETTE[0]) + +CANON_LABELS = ["nemo", "comma", "dclm"] + + +def create_isoflop_plot( + df: pd.DataFrame, + minima_records: list[dict], + fit_curves: dict[tuple[str, float], tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]], +) -> go.Figure: + """Create the IsoFLOP plot showing loss vs tokens for each compute budget. 
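A minimal sketch of the inputs this function expects, with invented numbers; only the DataFrame columns and record keys that the plotting code actually reads are shown:

    import pandas as pd
    from marin.scaling_laws.scaling_plots import create_isoflop_plot

    df = pd.DataFrame(
        dict(
            tokens=[1e9, 3e9, 1e10],
            loss=[1.20, 1.08, 1.15],
            flops=[1e19] * 3,
            params=[4.1e8] * 3,
            name=["run-a", "run-b", "run-c"],
            label=["nemo"] * 3,
        )
    )
    minima_records = [
        dict(label="nemo", flops=1e19, optimal_tokens=4.2e9, loss_at_optimal=1.07, optimal_params=4.1e8),
    ]
    fit_curves = {("nemo", 1e19): (0.12, -2.31, 12.19)}  # (a, b, c) of the quadratic in log10(tokens)

    fig = create_isoflop_plot(df, minima_records, fit_curves)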
+ + Args: + df: DataFrame with columns: tokens, loss, flops, params, name, label + minima_records: List of dicts with optimal config info per (label, flops) + fit_curves: Dict of {(label, flops): (a, b, c)} quadratic fit coefficients + + Returns: + Plotly Figure with the isoflop visualization + """ + if df.empty: + return go.Figure() + + present = list(dict.fromkeys(df["label"].tolist())) + datasets = [lab for lab in CANON_LABELS if lab in present] + [lab for lab in present if lab not in CANON_LABELS] + + buckets = sorted(df.flops.unique()) + bucket_color = {C: PALETTE[i % len(PALETTE)] for i, C in enumerate(buckets)} + ds_marker = {lab: MARKERS[i % len(MARKERS)] for i, lab in enumerate(datasets)} + + fig = go.Figure() + + # Build lookup for minima + minima_lookup = {(rec["label"], rec["flops"]): rec for rec in minima_records} + + for lab in datasets: + for C in buckets: + sub = df[(df.flops == C) & (df.label == lab)].sort_values("tokens") + if sub.empty: + continue + + # Scatter points + fig.add_trace( + go.Scatter( + x=sub.tokens, + y=sub.loss, + mode="markers", + marker=dict(symbol=ds_marker[lab], color=bucket_color[C], size=8), + name=f"{lab}, {C:.2e} FLOPs", + legendgroup=f"{lab}, {C:.2e}", + hovertemplate=( + "C=%{text:.2e} FLOPs
<br>tokens=%{x:.3e}<br>" + "loss=%{y:.4f}<br>
params=%{customdata:.3e}" + ), + text=[C] * len(sub), + customdata=sub.params.values, + ) + ) + + # Draw fit curve if available + key = (lab, C) + if key in fit_curves: + a, b, c = fit_curves[key] + if a != 0: + Ls = jnp.linspace(jnp.log10(sub.tokens.min()), jnp.log10(sub.tokens.max()), 200) + fig.add_trace( + go.Scatter( + x=10**Ls, + y=a * Ls**2 + b * Ls + c, + mode="lines", + line=dict(color=bucket_color[C], dash="dash", width=2), + showlegend=False, + legendgroup=f"{lab}, {C:.2e}", + ) + ) + + # Draw minimum marker + if key in minima_lookup: + rec = minima_lookup[key] + fig.add_trace( + go.Scatter( + x=[rec["optimal_tokens"]], + y=[rec["loss_at_optimal"]], + mode="markers", + marker=_MIN_MARKER, + showlegend=False, + legendgroup=f"{lab}, {C:.2e}", + hovertemplate=( + "Compute-optimal
" + "C=%{text:.2e} FLOPs
tokens=%{x:.3e}
" + "loss=%{y:.4f}
params=%{customdata:.3e}" + ), + text=[C], + customdata=[rec["optimal_params"]], + ) + ) + + fig.update_layout( + template="plotly_white", + xaxis_type="log", + xaxis_title="Tokens (log scale)", + yaxis_title="Bits Per Byte Validation", + title="Marin IsoFLOP Suite", + width=1000, + height=600, + ) + + return fig + + +def create_scaling_plot( + minima_records: list[dict], + scaling_fits: dict[str, tuple[float, float]], +) -> go.Figure: + """Create the scaling law fit plot showing N* vs compute budget. + + Args: + minima_records: List of dicts with optimal config info per (label, flops) + scaling_fits: Dict of {label: (alpha, A)} for N* ~ A * C^alpha + + Returns: + Plotly Figure with the scaling fit visualization + """ + if not minima_records: + return go.Figure() + + # Group by label + by_lab = {} + for rec in minima_records: + by_lab.setdefault(rec["label"], []).append(rec) + + present = list(by_lab.keys()) + datasets = [lab for lab in CANON_LABELS if lab in present] + [lab for lab in present if lab not in CANON_LABELS] + + fig = go.Figure() + + for i, lab in enumerate(datasets): + recs = by_lab.get(lab, []) + if not recs: + continue + + recs = sorted(recs, key=lambda r: r["flops"]) + Cs = jnp.array([r["flops"] for r in recs]) + Ns = jnp.array([r["optimal_tokens"] for r in recs]) + + color = PALETTE[i % len(PALETTE)] + dash = DASHES[i % len(DASHES)] + + # Plot minima points + fig.add_trace( + go.Scatter( + x=list(map(float, Cs)), + y=list(map(float, Ns)), + mode="markers", + marker=dict(symbol=_SCALE_MARKER["symbol"], size=_SCALE_MARKER["size"], color=color), + name=f"{lab} minima", + legendgroup=lab, + ) + ) + + # Plot fit line if available + if lab in scaling_fits: + alpha, A = scaling_fits[lab] + Cmin, Cmax = float(Cs.min()), float(Cs.max()) + C_fit = jnp.logspace(jnp.log10(Cmin) - 0.1, jnp.log10(Cmax) + 0.1, 400) + N_fit = A * (C_fit**alpha) + + fig.add_trace( + go.Scatter( + x=list(map(float, C_fit)), + y=list(map(float, N_fit)), + mode="lines", + line=dict(color=color, dash=dash, width=_SCALE_LINE["width"]), + name=f"{lab} fit (α={alpha:.3f})", + legendgroup=lab, + ) + ) + + fig.update_layout( + template="plotly_white", + xaxis_type="log", + yaxis_type="log", + xaxis_title="Compute budget C (FLOPs, log)", + yaxis_title="Optimal tokens N* (log)", + title="Scaling fits per dataset", + ) + + return fig + + +def save_plots( + fig_isoflop: go.Figure, + fig_scaling: go.Figure, + output_path: str, +) -> None: + """Save isoflop and scaling plots to HTML files. + + Args: + fig_isoflop: IsoFLOP plot figure + fig_scaling: Scaling fit plot figure + output_path: Directory path to save plots + """ + fs, _, _ = fsspec.get_fs_token_paths(output_path) + fs.makedirs(output_path, exist_ok=True) + + iso_path = os.path.join(output_path, "isoflop_plot.html") + scaling_path = os.path.join(output_path, "scaling_plot.html") + + with fs.open(iso_path, "w") as f: + f.write(fig_isoflop.to_html()) + logger.info(f"Wrote isoflop plot to {iso_path}") + + with fs.open(scaling_path, "w") as f: + f.write(fig_scaling.to_html()) + logger.info(f"Wrote scaling plot to {scaling_path}") + + +def upload_plots_to_wandb( + fig_isoflop: go.Figure, + fig_scaling: go.Figure, + entity: str = "marin-community", + project: str = "marin-analysis", + run_name: str = "scaling-ladder-analysis", +) -> None: + """Upload plots to Weights & Biases. 
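A sketch of how the pieces in this module chain together at the end of an analysis run, assuming `df`, `minima_records`, `fit_curves`, and `scaling_fits` were produced by the fitting step; the output path is a placeholder:

    from marin.scaling_laws.scaling_plots import (
        create_isoflop_plot,
        create_scaling_plot,
        save_plots,
        upload_plots_to_wandb,
    )

    # df, minima_records, fit_curves, scaling_fits come from the analysis step.
    fig_iso = create_isoflop_plot(df, minima_records, fit_curves)
    fig_scale = create_scaling_plot(minima_records, scaling_fits)
    save_plots(fig_iso, fig_scale, "gs://my-bucket/analysis/isoflop")  # placeholder path
    upload_plots_to_wandb(fig_iso, fig_scale, run_name="isoflop-analysis-demo")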
+ + Args: + fig_isoflop: IsoFLOP plot figure + fig_scaling: Scaling fit plot figure + entity: WandB entity + project: WandB project + run_name: Name for the WandB run + """ + if not WANDB_AVAILABLE: + logger.warning("wandb not available, cannot upload plots") + return + + wandb.login() + run = wandb.init( + entity=entity, + project=project, + job_type="scaling-ladder", + name=run_name, + resume="allow", + ) + wandb.log( + { + "isoFLOP_plot": wandb.Plotly(fig_isoflop), + "scaling_plot": wandb.Plotly(fig_scaling), + } + ) + run.finish() + logger.info("Uploaded plots to WandB") diff --git a/lib/marin/src/marin/scaling_laws/utils.py b/lib/marin/src/marin/scaling_laws/utils.py deleted file mode 100644 index ed669f7300..0000000000 --- a/lib/marin/src/marin/scaling_laws/utils.py +++ /dev/null @@ -1,667 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Functions for fitting scaling laws and plotting the results. - -The functions in this file implement the techniques in https://arxiv.org/pdf/2412.04403. - -Our objective is to predict the accuracy of a larger target model on a specific benchmark. -This prediction is done through a two-step modeling process using (N, D) data from various smaller models: -- we first fit a power-law model to predict the task loss from the number of parameters and tokens. -- then, we fit a sigmoidal model to predict the task accuracy from the task loss. - -For further details see the corresponding GitHub issue: https://github.com/marin-community/marin/issues/646. - -To use this code, call fit_scaling_laws() with appropriate arguments. -""" - -from collections.abc import Callable, Sequence -from dataclasses import dataclass -from typing import Any - -import matplotlib.pyplot as plt -import numpy as np -from scipy.optimize import curve_fit, minimize -from scipy.special import huber - -from experiments.llama import compute_num_parameters, llama3_tokenizer_vocab_size - -try: - import pandas as pd -except ImportError: - pd: Any = None - -OPTIMIZATION_TOLERANCE = 1e-10 - -#################################################################################################### -# Power law helpers - - -def power_law_model(params: Sequence[float], N: np.ndarray, D: np.ndarray, use_log_space: bool = True) -> np.ndarray: - """ - Power-law equation: A / N^alpha + B / D^beta + E - - Args: - params: List of parameters [A, B, alpha, beta, E] - N: Number of parameters - D: Number of tokens - use_log_space: Whether to use log space for A and B - """ - if use_log_space: - log_A, log_B, alpha, beta, E = params - A, B = np.exp(log_A), np.exp(log_B) - else: - A, B, alpha, beta, E = params - return A / (N**alpha) + B / (D**beta) + E - - -def power_law_loss( - params: Sequence[float], - N: np.ndarray, - D: np.ndarray, - y: np.ndarray, - use_log_space: bool, - delta: float, - reduction: Callable[[np.ndarray], float] | None = np.sum, -) -> float: - """ - Huber loss for the power-law model. 
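A quick numeric check of the power-law form defined above, with invented coefficients:

    import numpy as np

    N = np.array([1e8, 1e9, 1e10])   # parameters
    D = np.array([2e9, 2e10, 2e11])  # tokens
    A, B, alpha, beta, E = 4.0e2, 1.2e2, 0.34, 0.28, 0.6   # invented values

    loss = A / N**alpha + B / D**beta + E   # same form as power_law_model with use_log_space=False
    print(loss)   # decreases as N and D grow, approaching the irreducible term E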
- Args: - params: List of parameters [A, B, alpha, beta, E] - N: Number of parameters - D: Number of tokens - y: Actual loss - use_log_space: if true, residual is set to difference of logs of actual and predicted values - delta: huber loss delta, indicating the quadratic vs. linear loss changepoint. - reduction: Optional argument to change the reduction used on the Huber loss, defaults to sum based on https://arxiv.org/pdf/2404.10102v2 - """ - predictions = power_law_model(params, N, D, use_log_space) - if use_log_space: - residuals = np.log(y) - np.log(predictions) - else: - residuals = y - predictions - return reduction(huber(delta, residuals)) - - -def fit_power_law( - N: np.ndarray, - D: np.ndarray, - y: np.ndarray, - use_log_space: bool = False, - initial_guess: Sequence[float] | None = None, - delta: float = 1e-3, -) -> np.ndarray | tuple[float, float, float, float, float]: - """ - Fit a power law model to the data ((N, D), y). - - Args: - N: Number of parameters - D: Number of tokens - y: Actual loss or metric we want to learn to predict - use_log_space: if true, A and B are in log space *AND* Huber loss is computed in log space. - initial_guess: Initial guess for the parameters - delta: huber loss delta, indicating the quadratic vs. linear loss changepoint. - """ - # Compute the minimum y value to use as the initial guess for E - min_y = np.min(y) - - if use_log_space: - if initial_guess is None: - # Initialize E to max(min_y, 1e-10) to ensure it's positive - initial_guess = [0.0, 0.0, 1.0, 1.0, max(min_y, 1e-10)] # [log_A, log_B, alpha, beta, E] - bounds = [ - (None, None), # log_A unbounded - (None, None), # log_B unbounded - (0, None), # alpha >= 0 - (0, None), # beta >= 0 - (0, None), # E >= 0 - ] - else: - if initial_guess is None: - # Initialize E to max(min_y, 1e-10) to ensure it's positive - initial_guess = [1.0, 1.0, 1.0, 1.0, max(min_y, 1e-10)] # [A, B, alpha, beta, E] - bounds = [ - (0, None), # A >= 0 - (0, None), # B >= 0 - (0, None), # alpha >= 0 - (0, None), # beta >= 0 - (1e-10, None), # E >= 1e-10 to ensure E is positive - ] - - def objective(params): - return power_law_loss(params, N, D, y, use_log_space, delta) - - result = minimize( - objective, - initial_guess, - method="L-BFGS-B", - bounds=bounds, - options={"ftol": OPTIMIZATION_TOLERANCE, "gtol": OPTIMIZATION_TOLERANCE, "maxiter": 2500}, - ) - - if not result.success: - raise RuntimeError(f"Optimization failed: {result.message}") - - # return the fitted parameters, converting log_A and log_B back to A and B if needed - if use_log_space: - log_A, log_B, alpha, beta, E = result.x - A, B = np.exp(log_A), np.exp(log_B) - return A, B, alpha, beta, E - else: - return result.x - - -def predict_power_law(params: Sequence[float], N: np.ndarray, D: np.ndarray) -> np.ndarray: - A, B, alpha, beta, E = params - return A / (N**alpha) + B / (D**beta) + E - - -#################################################################################################### -# Sigmoidal fit helpers - - -def fit_sigmoidal(L: np.ndarray, y: np.ndarray, initial_guess: Sequence[float] | None = None) -> np.ndarray: - """ - Fit a sigmoidal model to the data (L, y). 
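Before the sigmoid details, a synthetic round trip through `fit_power_law` and `predict_power_law` defined earlier in this hunk; the module is the one removed by this patch, the coefficients are invented, and the fit may need a hand-picked `initial_guess` to converge:

    import numpy as np
    from marin.scaling_laws.utils import fit_power_law, predict_power_law

    rng = np.random.default_rng(0)
    N = rng.uniform(1e8, 1e10, size=32)
    D = rng.uniform(1e9, 1e12, size=32)
    true_params = (4.0e2, 1.2e2, 0.34, 0.28, 0.6)
    y = predict_power_law(true_params, N, D) + rng.normal(0.0, 1e-3, size=32)

    fitted = fit_power_law(N, D, y, use_log_space=False)
    print(fitted)   # should land near true_params, up to the usual identifiability caveats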
- - Equation: a / (1 + exp(-k * (L - L_0))) + b - - Args: - L: Task loss (input array) - y: Ground-truth task accuracy (output array) - initial_guess: Initial guess for [a, b, k, L_0], defaults to data-driven values - - Returns: - popt: Optimized parameters [a, b, k, L_0] - """ - # Set initial guess if not provided - if initial_guess is None: - y_min, y_max = np.min(y), np.max(y) - a_init = y_max - y_min # amplitude - b_init = y_min # offset - k_init = -1.0 # slope (negative for decreasing sigmoid) - L_0_init = np.mean(L) # midpoint - initial_guess = [a_init, b_init, k_init, L_0_init] - - # Set parameter bounds - lower_bounds = [0, 0, -np.inf, -np.inf] # a > 0, b >= 0, k unbounded below, L_0 unbounded - upper_bounds = [np.inf, np.inf, 0, np.inf] # a unbounded above, b unbounded, k <= 0, L_0 unbounded - bounds = (lower_bounds, upper_bounds) - - def objective(L, a, b, k, L_0): - return predict_sigmoidal([a, b, k, L_0], L) - - # Fit the model using scipy's curve_fit() - # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html - popt, _ = curve_fit( - objective, L, y, p0=initial_guess, bounds=bounds, maxfev=15000, method="trf", ftol=OPTIMIZATION_TOLERANCE - ) - - return popt - - -def predict_sigmoidal(params: Sequence[float], L: np.ndarray) -> np.ndarray: - a, b, k, L_0 = params - return a / (1 + np.exp(-k * (L - L_0))) + b - - -#################################################################################################### -# WandB and data processing helpers - - -def pull_metrics_from_wandb( - runs: Sequence[str], - metrics: Sequence[str], - entity: str, - project: str, - summary_fields: Sequence[str] = ("parameter_count",), -) -> pd.DataFrame: - """ - Pulls the metrics from the given runs and returns a DataFrame. - - Args: - runs: List of run IDs - metrics: List of metrics to pull from the runs; these differ depending on the step (unlike summary_fields) - entity: WandB entity - project: WandB project - summary_fields: List of summary fields to pull from the runs - - Returns: - Pandas dataFrame with the metrics - """ - - import wandb - - api = wandb.Api() - - data = [] - for run_id in runs: - run = api.run(f"{entity}/{project}/{run_id}") - run_data = {"run": run.name} - - # Get model configuration - model_dict = run.train_config.get("model", {}) - - run_data["hidden_dim"] = model_dict.get("hidden_dim", 0) - - # get the summary fields - for field in summary_fields: - run_data[field] = run.summary.get(field, None) - - # get the per-step metrics - history = run.history(keys=metrics) - - for i in range(len(history)): - step_data = {m: history[m][i] for m in metrics} - step_data.update(run_data) - step_data["step"] = i - data.append(step_data) - - return pd.DataFrame(data) - - -def filter_zero_d(df: pd.DataFrame, d_key: str = "throughput/total_tokens") -> pd.DataFrame: - """ - Returns a new DataFrame that excludes any rows where the specified - 'd_key' column is zero. - """ - return df[df[d_key] != 0].copy() - - -def extract_scaling_data( - df: pd.DataFrame, - param_count_col: str = "parameter_count", - tokens_col: str = "throughput/total_tokens", - loss_col: str | None = None, - count_embedding_params: bool = False, -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Extracts N, D, and y from the given DataFrame. 
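Complementing the power-law stage, a sketch of the sigmoidal loss-to-accuracy fit defined earlier in this hunk; parameters and noise are invented, and again the imports point at the module removed by this patch:

    import numpy as np
    from marin.scaling_laws.utils import fit_sigmoidal, predict_sigmoidal

    a, b, k, L_0 = 0.55, 0.25, -6.0, 1.0          # invented ground truth
    L = np.linspace(0.7, 1.4, 20)                 # task loss (e.g. bpb)
    acc = predict_sigmoidal((a, b, k, L_0), L) + np.random.default_rng(0).normal(0.0, 5e-3, L.size)

    params = fit_sigmoidal(L, acc)                # recovers (a, b, k, L_0) approximately
    print(params)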
- - Args: - df: DataFrame - param_count_col: Column name for the parameter count - tokens_col: Column name for the tokens - loss_col: Column name for the loss - - Returns: - Tuple of numpy arrays: (N, D, y) where - N = Number of parameters (excluding embedding parameters) - D = Number of tokens - y = Loss - """ - - N = df[param_count_col].values - D = df[tokens_col].values - y = df[loss_col].values if loss_col is not None else None - - # Apply non_embedding_params element-wise - if not count_embedding_params: - N = np.array([non_embedding_params(n, h) for n, h in zip(N, df["hidden_dim"].values, strict=False)]) - - return N, D, y - - -def aggregate_steps( - df: pd.DataFrame, - step_mode: str = "all", - step_range: tuple[int, int] = (1, 5), - group_col: str = "run", -) -> pd.DataFrame: - """ - Aggregates the steps for each run. - - Args: - df: DataFrame - step_mode: how to aggregate the steps - step_range: range of steps to aggregate - group_col: column to group by - - step_mode can be: - - "average": average step_range across each run - - "last": pick the max step within step_range - - "all": keep every step (no grouping) - """ - - if step_mode == "average": - grouped = df.groupby(group_col, as_index=False).mean(numeric_only=True) - return grouped - elif step_mode == "last": - # pick the largest step in the range for each run - def pick_last(g): - last_step_idx = g["step"].idxmax() - return g.loc[last_step_idx] - - grouped = df.groupby(group_col, as_index=False).apply(pick_last) - return grouped.reset_index(drop=True) - elif step_mode == "all": - # no aggregation - return df.copy() - else: - raise ValueError(f"Unknown step_mode: {step_mode}") - - -def non_embedding_params(total_param_count: int, hidden_dim: int, vocab_size: int = llama3_tokenizer_vocab_size): - return total_param_count - 2 * hidden_dim * vocab_size - - -#################################################################################################### -# Projection-specific helpers - - -@dataclass -class ProjectionPoint: - """A point to project to, consisting of number of parameters and tokens""" - - num_params: int - num_tokens: int - - -def get_default_projection_points(count_embedding_params: bool = False) -> list[ProjectionPoint]: - """Default set of model sizes to project to - - Args: - count_embedding_params: Whether to include embedding parameters in parameter count - """ - from experiments.llama import llama_1_4b, llama_8b, llama_13b, llama_24b, llama_70b - - # Base model configs - model_configs = [ - llama_1_4b, - llama_8b, - llama_13b, - llama_24b, - llama_70b, - ] - - # Token multipliers (relative to parameter count - token_multipliers = [0.5, 1, 5, 10, 20, 30, 50, 100] - - points = [] - for config in model_configs: - # Calculate total parameters - total_params = compute_num_parameters(config, llama3_tokenizer_vocab_size) - - # Adjust if we're not counting embedding params - if not count_embedding_params: - total_params = non_embedding_params(total_params, config.hidden_dim) - - # Create points with different token counts - for multiplier in token_multipliers: - num_tokens = int(total_params * multiplier) - points.append(ProjectionPoint(total_params, num_tokens)) - - return points - - -#################################################################################################### -# Plotting helpers - - -def plot_actual_vs_predicted( - y_actual: np.ndarray, - y_predicted: np.ndarray, - title: str = "Actual vs Predicted", - task_metric: str = "eval/paloma/c4_en/bpb", - tokens: np.ndarray | None = None, -) -> 
None: - """ - Plot actual vs predicted values. task_metric is the name of the metric we are predicting. - """ - plt.figure(figsize=(10, 6)) - - x_values = tokens if tokens is not None else np.arange(len(y_actual)) - - # plot actual and predicted values - plt.plot(x_values, y_actual, label="Actual", marker="o", linestyle="-", linewidth=2) - plt.plot(x_values, y_predicted, label="Predicted", marker="x", linestyle="--", linewidth=2) - - # add labels, legend, and title - plt.xlabel("Tokens" if tokens is not None else "Step") - plt.ylabel("Metric: " + task_metric) - plt.title(title) - plt.legend() - plt.grid(True) - - # Format tick labels to show B/T for billions/trillions - if tokens is not None: - - def format_ticks(x, _): - if x >= 1e12: - return f"{x/1e12:.1f}T" - elif x >= 1e9: - return f"{x/1e9:.1f}B" - else: - return f"{x/1e6:.1f}M" - - plt.gca().xaxis.set_major_formatter(plt.FuncFormatter(format_ticks)) - - return plt - - -def plot_scaling_projections(predicted: np.ndarray, points: list[ProjectionPoint] | None = None): - """ - Plot scaling law predictions vs tokens for specified model sizes. - - Args: - predicted: Array of predicted values - points: List of ProjectionPoint objects containing parameter and token counts - - Returns: - matplotlib.pyplot figure object - """ - plt.figure(figsize=(12, 6)) - unique_params = np.unique([p.num_params for p in points]) - - for param in unique_params: - mask = np.array([p.num_params == param for p in points]) - tokens = np.array([p.num_tokens for p in points])[mask] - preds = predicted[mask] - plt.plot(tokens, preds, "o-", linewidth=2, label=f"{param/1e9:.1f}B params") - - # add annotations for each point - for t, pred in zip(tokens, preds, strict=False): - token_str = f"{t/1e9:.1f}B" if t < 1e11 else f"{t/1e12:.1f}T" - plt.annotate(f"{token_str}, {pred:.3f}", (t, pred), ha="center", va="bottom", fontsize=6) - - plt.xscale("log") - plt.xlabel("Number of Tokens") - plt.ylabel("Predicted Loss") - plt.grid(True) - plt.legend() - return plt - - -#################################################################################################### -# Functions for fitting scaling laws - - -def fit_scaling_laws( - runs: list[str], - loss_metrics: Sequence[str], - accuracy_metrics: Sequence[str] | None, - entity: str, - project: str, - pred_run: str | None = None, - projection_points: list[ProjectionPoint] | None = None, - aggregation: str = "all", - tokens_col: str = "throughput/total_tokens", - param_col: str = "parameter_count", - count_embedding_params: bool = False, - use_log_for_ND: bool = False, - normalize_ND: bool = False, -) -> tuple[dict[str, np.ndarray], tuple[dict, dict, np.ndarray, np.ndarray] | None]: - """Fit scaling laws for both projection and prediction - - Args: - runs: list of run IDs to fit scaling laws for - loss_metrics: list of loss metrics to fit scaling laws for - accuracy_metrics: list of accuracy metrics to fit scaling laws for - entity: WandB entity - project: WandB project - pred_run: run ID to predict scaling laws for- if None, no prediction is done - projection_points: list of ProjectionPoint objects to project to - aggregation: how to aggregate steps within each run (all/last/average) - tokens_col: column name for the number of tokens - param_col: column name for the number of parameters - count_embedding_params: whether to count embedding parameters in calculating N - use_log_for_ND: whether to use log space for N and D - normalize_ND: whether to normalize N and D - - Returns: - tuple of: - - dict of loss metrics and their 
predictions - - dict of accuracy metrics and their predictions - - numpy array of tokens for x-axis of plots for losses - - numpy array of tokens for x-axis of plots for accuracies - """ - - # First pull for losses - only essential metrics - metrics = [*list(loss_metrics), tokens_col] - loss_df = pull_metrics_from_wandb( - runs=runs, - metrics=metrics, - entity=entity, - project=project, - summary_fields=(param_col,), - ) - - # Process loss data- remove 0-token runs, apply aggregation to the ladder runs' checkpoints (if specified) - loss_df_filtered = filter_zero_d(loss_df, tokens_col) - loss_df_agg = aggregate_steps(loss_df_filtered, step_mode=aggregation) - - # Get N, D - N, D, _ = extract_scaling_data(loss_df_agg, param_col, tokens_col, count_embedding_params=count_embedding_params) - if use_log_for_ND: - N = np.log(N) - D = np.log(D) - if normalize_ND: - N_scale = np.mean(N) - D_scale = np.mean(D) - N = N / N_scale - D = D / D_scale - - # Handle projections - projections = {} - - if projection_points: - N_proj = np.array([point.num_params for point in projection_points]) - D_proj = np.array([point.num_tokens for point in projection_points]) - - if use_log_for_ND: - N_proj, D_proj = np.log(N_proj), np.log(D_proj) - if normalize_ND: - N_proj, D_proj = N_proj / N_scale, D_proj / D_scale - - for loss_metric in loss_metrics: - y = loss_df_agg[loss_metric].values - params = fit_power_law(N, D, y, use_log_space=True) - projections[loss_metric] = predict_power_law(params, N_proj, D_proj) - - predictions = None - if pred_run: - loss_pred_df = pull_metrics_from_wandb( - runs=[pred_run], - metrics=[*list(loss_metrics), tokens_col], - entity=entity, - project=project, - summary_fields=(param_col,), - ) - - loss_pred_filtered = filter_zero_d(loss_pred_df, tokens_col) - loss_pred_agg = aggregate_steps(loss_pred_filtered, step_mode=aggregation) - - N_pred, D_pred, _ = extract_scaling_data( - loss_pred_agg, param_col, tokens_col, count_embedding_params=count_embedding_params - ) - if use_log_for_ND: - N_pred = np.log(N_pred) - D_pred = np.log(D_pred) - if normalize_ND: - N_pred = N_pred / N_scale - D_pred = D_pred / D_scale - - # Fit losses - loss_results = {} - for loss_metric in loss_metrics: - y = loss_df_agg[loss_metric].values - params = fit_power_law(N, D, y, use_log_space=True) - actual_loss = loss_pred_agg[loss_metric].values - predicted_loss = predict_power_law(params, N_pred, D_pred) - loss_results[loss_metric] = (actual_loss, predicted_loss) - - # Second pull for accuracies - accuracy_results = {} - if accuracy_metrics: - acc_df = pull_metrics_from_wandb( - runs=runs, - metrics=[*list(accuracy_metrics), tokens_col], - entity=entity, - project=project, - summary_fields=(param_col,), - ) - acc_pred_df = pull_metrics_from_wandb( - runs=[pred_run], - metrics=[*list(accuracy_metrics), tokens_col], - entity=entity, - project=project, - summary_fields=(param_col,), - ) - - acc_df_filtered = filter_zero_d(acc_df, tokens_col) - acc_df_agg = aggregate_steps(acc_df_filtered, step_mode=aggregation) - acc_pred_filtered = filter_zero_d(acc_pred_df, tokens_col) - acc_pred_agg = aggregate_steps(acc_pred_filtered, step_mode=aggregation) - - # Fit accuracies - accuracy_results = {} - loss_metric, (actual_loss, predicted_loss) = next(iter(loss_results.items())) # use first loss - - # Merge loss and accuracy data on run and tokens - merged_df = pd.merge( - loss_df_agg[["run", tokens_col, loss_metric]], - acc_df_agg[["run", tokens_col, *accuracy_metrics]], - on=["run", tokens_col], - how="inner", - ) - - # 
Merge prediction data similarly - merged_pred_df = pd.merge( - loss_pred_agg[["run", tokens_col]], - acc_pred_agg[["run", tokens_col, *accuracy_metrics]], - on=["run", tokens_col], - how="inner", - ) - - for acc_metric in accuracy_metrics: - task_losses = merged_df[loss_metric].values - acc = merged_df[acc_metric].values - params = fit_sigmoidal(task_losses, acc) - - acc_pred_actual = merged_pred_df[acc_metric].values - # Get the corresponding predicted losses for these points - pred_indices = loss_pred_agg[tokens_col].isin(merged_pred_df[tokens_col]) - pred_task_losses = predicted_loss[pred_indices] - - acc_preds = predict_sigmoidal(params, pred_task_losses) - accuracy_results[f"{acc_metric}_from_{loss_metric}"] = (acc_pred_actual, acc_preds) - - # Get token counts for plotting - loss_tokens = loss_pred_agg[tokens_col].values - acc_tokens = merged_pred_df[tokens_col].values - - predictions = (loss_results, accuracy_results, loss_tokens, acc_tokens) - - return projections, predictions diff --git a/scripts/migrations/migrate_isoflop_wandb_runs.py b/scripts/migrations/migrate_isoflop_wandb_runs.py new file mode 100644 index 0000000000..94df5bf55a --- /dev/null +++ b/scripts/migrations/migrate_isoflop_wandb_runs.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +""" +Migrate WandB isoflop runs to match migrated checkpoint paths. + +After scripts/migrations/migrate_isoflop_checkpoints.sh strips the 6-char hash +suffix from checkpoint paths (e.g., 'isoflop-1e+19-d2048-nemo-abc123' becomes +'isoflop-1e+19-d2048-nemo'), this script copies the corresponding WandB runs +to have matching names without the hash suffix. + +This enables eval_metrics_reader.py to find WandB runs by checkpoint name +without needing complex override mappings. +""" + +import argparse +import logging +import re +import sys +from typing import Optional + +try: + import wandb +except ImportError: + print("Error: wandb package not installed. Install with: pip install wandb") + sys.exit(1) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +def copy_wandb_run( + api: wandb.Api, + source_run: wandb.apis.public.Run, + new_name: str, + entity: str, + project: str, + dry_run: bool = True, +) -> bool: + """ + Copy a WandB run to a new run with a different name. + + Args: + api: WandB API instance + source_run: The source run to copy + new_name: The new name for the copied run + entity: WandB entity + project: WandB project + dry_run: If True, don't actually create the run + + Returns: + True if successful (or would be successful in dry run) + """ + if dry_run: + logger.info(f" [DRY RUN] Would copy {source_run.name} -> {new_name}") + return True + + try: + # Initialize a new run with the clean name + new_run = wandb.init( + entity=entity, + project=project, + name=new_name, + id=new_name, # Use name as ID to make it deterministic + resume="never", + config=dict(source_run.config), + tags=list(source_run.tags), + ) + + # Copy summary metrics + summary = dict(source_run.summary) + for key, value in summary.items(): + new_run.summary[key] = value + + logger.info(f" Created new run: {new_name}") + new_run.finish() + return True + + except Exception as e: + logger.error(f" Failed to copy run {source_run.name}: {e}") + return False + + +def migrate_isoflop_wandb_runs( + entity_project: str, + run_name_filter: Optional[str] = None, + dry_run: bool = True, +) -> None: + """ + Migrate WandB isoflop runs by copying them without hash suffixes. 
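+
+    A minimal illustration of the rename (the run name comes from the module
+    docstring above, and this is the same regex the code below applies):
+
+        >>> import re
+        >>> re.sub(r"-[0-9a-fA-F]{6}$", "", "isoflop-1e+19-d2048-nemo-abc123")
+        'isoflop-1e+19-d2048-nemo'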
+ + Args: + entity_project: WandB entity/project (format: 'entity/project') + run_name_filter: Optional filter to only process specific runs + dry_run: If True, don't actually create runs + """ + if "/" not in entity_project: + raise ValueError(f"Invalid entity_project format: {entity_project}. Expected 'entity/project'") + + entity, project = entity_project.split("/", 1) + api = wandb.Api() + + logger.info(f"Querying WandB for isoflop runs in {entity_project}...") + + # Query for isoflop runs with hash suffixes + filters = { + "displayName": {"$regex": "isoflop"}, + "state": "finished", + } + + runs = api.runs(entity_project, filters=filters) + + migrated_count = 0 + skipped_count = 0 + error_count = 0 + + for run in runs: + display_name = run.displayName + + # Check if this run has a hash suffix + if not re.search(r"-[0-9a-fA-F]{6}$", display_name): + logger.debug(f"Skipping {display_name} (no hash suffix)") + skipped_count += 1 + continue + + # Strip the hash to get the clean name + clean_name = re.sub(r"-[0-9a-fA-F]{6}$", "", display_name) + + # Apply filter if specified + if run_name_filter and run_name_filter not in clean_name: + logger.debug(f"Skipping {display_name} (doesn't match filter)") + skipped_count += 1 + continue + + # Check if a run with the clean name already exists + try: + existing = api.run(f"{entity_project}/{clean_name}") + logger.info(f"Skipping {display_name} -> {clean_name} (already exists)") + skipped_count += 1 + continue + except wandb.errors.CommError: + # Run doesn't exist, we can create it + pass + + logger.info(f"Processing: {display_name} -> {clean_name}") + + if copy_wandb_run(api, run, clean_name, entity, project, dry_run): + migrated_count += 1 + else: + error_count += 1 + + logger.info("\n" + "=" * 60) + logger.info("Migration Summary:") + logger.info(f" Migrated: {migrated_count}") + logger.info(f" Skipped: {skipped_count}") + logger.info(f" Errors: {error_count}") + + if dry_run: + logger.info("\nDry run complete. 
Run with --execute to perform the migration.") + + +def main(): + parser = argparse.ArgumentParser( + description="Migrate WandB isoflop runs to match migrated checkpoint paths", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Dry run (default) + python migrate_isoflop_wandb_runs.py marin-community/marin + + # Execute the migration + python migrate_isoflop_wandb_runs.py marin-community/marin --execute + + # Filter to specific runs + python migrate_isoflop_wandb_runs.py marin-community/marin --filter nemo --execute + """, + ) + + parser.add_argument( + "entity_project", + help="WandB entity/project (format: entity/project)", + ) + + parser.add_argument( + "--execute", + action="store_true", + help="Actually perform the migration (default is dry run)", + ) + + parser.add_argument( + "--filter", + help="Only process runs whose clean name contains this string", + ) + + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + try: + migrate_isoflop_wandb_runs( + entity_project=args.entity_project, + run_name_filter=args.filter, + dry_run=not args.execute, + ) + except Exception as e: + logger.error(f"Migration failed: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/uv.lock b/uv.lock index 965f641c45..aed186490a 100644 --- a/uv.lock +++ b/uv.lock @@ -3145,6 +3145,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/17/c6d9dc31001a495cb3c52fa69b22a0d8812880cb853f7c0573e2a5edad82/jaxlib-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:659d894d93876e3675c2132d13c3d241f204b21172a58f928b96f654f603f6dc", size = 59323262, upload-time = "2025-10-15T23:10:46.607Z" }, ] +[[package]] +name = "jaxopt" +version = "0.8.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jax" }, + { name = "jaxlib" }, + { name = "numpy" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/da/ff7d7fbd13b8ed5e8458e80308d075fc649062b9f8676d3fc56f2dc99a82/jaxopt-0.8.5.tar.gz", hash = "sha256:2790bd68ef132b216c083a8bc7a2704eceb35a92c0fc0a1e652e79dfb1e9e9ab", size = 121709, upload-time = "2025-04-14T17:59:01.618Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/d8/55e0901103c93d57bab3b932294c216f0cbd49054187ce29f8f13808d530/jaxopt-0.8.5-py3-none-any.whl", hash = "sha256:ff221d1a86908ec759eb1e219ee1d12bf208a70707e961bf7401076fe7cf4d5e", size = 172434, upload-time = "2025-04-14T17:59:00.342Z" }, +] + [[package]] name = "jaxtyping" version = "0.3.5" @@ -4317,6 +4332,7 @@ dependencies = [ { name = "google-cloud-storage-transfer" }, { name = "haliax" }, { name = "jax" }, + { name = "jaxopt" }, { name = "levanter", extra = ["serve"] }, { name = "lm-eval" }, { name = "lxml", extra = ["html-clean"] }, From 348339d49c12772ba9ff72ffabb25e92b124d881 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 18 Dec 2025 11:10:10 -0800 Subject: [PATCH 08/79] More Cleanly Updates --- .../exp2166_scaling_ladder_analysis.py | 67 +-- experiments/isoflop_sweep.py | 141 ++--- lib/marin/pyproject.toml | 2 + lib/marin/src/marin/scaling_laws/__init__.py | 20 +- .../marin/scaling_laws/eval_metrics_reader.py | 5 +- .../marin/scaling_laws/isoflop_analysis.py | 279 ++++++++-- .../src/marin/scaling_laws/scaling_ladder.py | 64 ++- .../src/marin/scaling_laws/scaling_plots.py | 10 +- tests/test_scaling_laws.py | 498 ++++++++++++++++++ uv.lock | 4 + 10 files changed, 891 
insertions(+), 199 deletions(-) create mode 100644 tests/test_scaling_laws.py diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 8f460572a4..962ad8981a 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -12,69 +12,60 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Exp2166: IsoFLOP Analysis and Scaling Ladders for Nemotron, Comma, and Dolma3. +"""Exp2166: Scaling Ladder Analysis for Nemotron and Dolma3. -This experiment runs IsoFLOP analysis on the isoflop training sweeps -for three datasets: +This experiment runs scaling ladder analysis on the isoflop training sweeps +for two datasets: - Nemotron (nemo-wider-depth-adapt) -- Common Pile / Comma (comma-mix) - Dolma3 (dolma3-mix-150b-1025) -The IsoFLOP analysis fits scaling laws to find compute-optimal configurations and -generates visualization plots. It also demonstrates scaling ladder runs (compute-optimal -training runs) that use the predicted configurations. +The scaling ladder: +1. Fits scaling laws from IsoFLOP sweep data to find compute-optimal configurations +2. Generates visualization plots (isoflop curves and scaling fit plots) +3. Optionally trains compute-optimal models at larger target budgets + +The analysis steps depend on completed isoflop training runs from isoflop_sweep.py. +Once complete, results are saved to the output path and uploaded to WandB. """ -from experiments.isoflop_sweep import MARIN_SCALING_SUITES, nemotron_mix, dolma3_mix +from experiments.isoflop_sweep import MARIN_SCALING_SUITES, dolma3_mix, nemotron_mix from marin.execution.executor import executor_main -from marin.scaling_laws import isoflop_analysis_step, scaling_ladder_suite +from marin.scaling_laws import scaling_ladder_suite -# Get training steps for each dataset (eval_tasks=None by default, so only training steps) +# Get training steps and datasets for each suite nemotron_training, _ = MARIN_SCALING_SUITES["nemotron"] -comma_training, _ = MARIN_SCALING_SUITES["common_pile"] dolma3_training, _ = MARIN_SCALING_SUITES["dolma3_mix_150b"] -# --- IsoFLOP analysis-only steps (no scaling ladder rungs) --- - -nemotron_analysis = isoflop_analysis_step( - name="exp2166-isoflop-analysis-nemotron", - training_runs=nemotron_training, - wandb_run_name="exp2166-isoflop-analysis-nemotron", -) - - -dolma3_analysis = isoflop_analysis_step( - name="exp2166-isoflop-analysis-dolma3", - training_runs=dolma3_training, - wandb_run_name="exp2166-isoflop-analysis-dolma3", -) +# --- Scaling Ladder Suites --- +# These analyze completed isoflop training runs and optionally train compute-optimal models +# Target budgets for compute-optimal training runs (beyond the isoflop sweep) +# Set to empty list to only run analysis without training +TARGET_BUDGETS: list[float] = [] -# --- Full scaling ladder suites --- -# These create IsoFLOP analysis + scaling ladder rungs (optimal training runs) for target budgets -# Nemotron suite: analyze isoflop runs, then train optimal models at larger budgets nemotron_suite = scaling_ladder_suite( - name="exp2166-nemo", + name="exp2166-scaling-ladder-nemotron", training_runs=nemotron_training, - target_budgets=[1e21, 3e21], - label="nemo", - dataset=nemotron_mix, + target_budgets=TARGET_BUDGETS, + label="nemo-wider-depth-adapt", + tokenized=nemotron_mix, + wandb_project="marin-analysis", ) -# Dolma3 suite dolma3_suite = scaling_ladder_suite( - 
name="exp2166-dolma3", + name="exp2166-scaling-ladder-dolma3", training_runs=dolma3_training, - target_budgets=[1e21, 3e21], - label="dolma3", - dataset=dolma3_mix, + target_budgets=TARGET_BUDGETS, + label="dolma3-mix-150b-1025", + tokenized=dolma3_mix, + wandb_project="marin-analysis", ) -all_steps = [nemotron_analysis, dolma3_analysis, *nemotron_suite.all_steps, *dolma3_suite.all_steps] +all_steps = [*nemotron_suite.all_steps, *dolma3_suite.all_steps] if __name__ == "__main__": executor_main(steps=all_steps) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 3c615e544a..650f9ca286 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -20,20 +20,17 @@ """ import dataclasses -import os from dataclasses import dataclass, replace from levanter.data.text import LMMixtureDatasetConfig -from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig -from levanter.models.qwen import Qwen3Config from levanter.optim.cautious import CautiousConfig from levanter.optim.config import OptimizerConfig from experiments.evals.evals import default_eval -from experiments.evals.task_configs import MMLU_5_SHOT, EvalTaskConfig +from experiments.evals.task_configs import EvalTaskConfig from experiments.common_pile.tokenize_common_pile import comma_main_mixture from experiments.defaults import default_tokenize, default_train -from experiments.llama import compute_num_parameters, llama3_tokenizer +from experiments.llama import llama3_tokenizer from experiments.metrics.wandb_related import get_vocab_size_for_tokenizer from experiments.pretraining_datasets.simple import downloads from experiments.simple_train_config import SimpleTrainConfig @@ -43,14 +40,13 @@ from marin.processing.tokenize import lm_mixture_data_config from marin.scaling_laws.isoflop_analysis import ( CandidateConfig, - DEFAULT_BUDGETS, IsoFlopSweepConfig, - candidate_configs, - pick_v5p_type, + IsoFlopTrainArgs, + generate_isoflop_train_args, ) -@dataclass +@dataclass(frozen=True) class IsoFlopExperimentConfig(IsoFlopSweepConfig): """Extended config for isoflop experiments with dataset and eval settings. @@ -94,107 +90,70 @@ class IsoFlopExperimentConfig(IsoFlopSweepConfig): ) -def _pick_v5p_type_for_model( - config: Qwen3Config, - hidden: int, - layers: int, - batch: int, - seq_len: int, - vocab: int, -) -> str: - """Select the smallest TPU v5p slice that fits the model in float32.""" - param_count = compute_num_parameters(config, vocab) - return pick_v5p_type(param_count, hidden, layers, batch, seq_len, vocab) - - def generate_isoflop_steps( config: IsoFlopExperimentConfig, experiment_name: str, ) -> tuple[list[ExecutorStep], list[CandidateConfig]]: """Generate executor steps for an ISOFlop sweep. + Uses generate_isoflop_train_args() from the scaling_laws library to get + model configs, optimizer configs, and other arguments, then constructs + ExecutorSteps using default_train(). + Returns: A tuple of: - steps: Training and evaluation ExecutorSteps for the sweep. - candidates: CandidateConfig for each training run (contains budget, hidden_size, num_layers, batch_size, train_steps, learning_rate, etc.) 
""" + vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) + + # Get training arguments from the library + train_args_list = generate_isoflop_train_args( + sweep_config=config, + experiment_name=experiment_name, + vocab_size=vocab_size, + base_optimizer_config=config.base_optimizer_config, + ) train_steps_list: list[ExecutorStep] = [] eval_steps: list[ExecutorStep] = [] candidates: list[CandidateConfig] = [] - vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) - for budget in config.budgets: - for candidate in candidate_configs(config, budget, vocab_size): - model_cfg = Qwen3Config( - max_seq_len=config.seq_len, - hidden_dim=candidate.hidden_size, - intermediate_dim=candidate.intermediate_dim, - num_heads=candidate.num_heads, - num_kv_heads=candidate.num_kv_heads, - num_layers=candidate.num_layers, - rope=Llama3RotaryEmbeddingsConfig(), - ) - tpu_type = _pick_v5p_type_for_model( - config=model_cfg, - hidden=candidate.hidden_size, - layers=candidate.num_layers, - batch=candidate.batch_size, - seq_len=config.seq_len, - vocab=vocab_size, - ) - optimizer_cfg = replace( - config.base_optimizer_config, - learning_rate=candidate.learning_rate, - beta2=candidate.beta2, - ) - train_cfg = replace( - config.base_train_config, - train_batch_size=candidate.batch_size, - learning_rate=candidate.learning_rate, - num_train_steps=candidate.train_steps, - resources=ResourceConfig.with_tpu(tpu_type), - optimizer_config=optimizer_cfg, - ) + for args in train_args_list: + # Build SimpleTrainConfig from the library-provided arguments + train_cfg = replace( + config.base_train_config, + train_batch_size=args.candidate.batch_size, + learning_rate=args.candidate.learning_rate, + num_train_steps=args.candidate.train_steps, + resources=ResourceConfig.with_tpu(args.tpu_type), + optimizer_config=args.optimizer_config, + ) - run_name = ( - f"isoflop-{budget:.0e}-d{candidate.hidden_size}-" - f"L{candidate.num_layers}-B{candidate.batch_size}-{experiment_name}" - ) - train_step = default_train( - name=run_name, - tokenized=config.tokenized_dataset, - model_config=model_cfg, - train_config=train_cfg, - eval_harness_tasks=[], - tags=( - f"FLOPs={budget:.1e}", - f"d={candidate.hidden_size}", - f"L={candidate.num_layers}", - f"B={candidate.batch_size}", - f"steps={candidate.train_steps}", - f"tpu={tpu_type}", - ), - ) - candidates.append(candidate) - # Reuse checkpoints by pinning every sweep run to a deterministic directory. - static_output_path = os.path.join( - "checkpoints", - "isoflop", - run_name, + # Create training step using default_train + train_step = default_train( + name=args.run_name, + tokenized=config.tokenized_dataset, + model_config=args.model_config, + train_config=train_cfg, + eval_harness_tasks=[], + tags=args.tags, + ) + + # Pin to static output path for checkpoint reuse + train_step = train_step.with_output_path(args.output_path) + train_steps_list.append(train_step) + candidates.append(args.candidate) + + # Evaluation on the latest checkpoint for each ISOFlop run + if config.eval_tasks: + eval_step = default_eval( + train_step, + resource_config=train_cfg.resources, + evals=config.eval_tasks, ) - train_step = train_step.with_output_path(static_output_path) - train_steps_list.append(train_step) - - # Evaluation on the latest checkpoint for each ISOFlop run. 
- if config.eval_tasks: - eval_step = default_eval( - train_step, - resource_config=train_cfg.resources, - evals=config.eval_tasks, - ) - eval_steps.append(eval_step) + eval_steps.append(eval_step) all_steps: list[ExecutorStep] = [*train_steps_list, *eval_steps] return all_steps, candidates diff --git a/lib/marin/pyproject.toml b/lib/marin/pyproject.toml index 8f8e1f1c91..0548f5e392 100644 --- a/lib/marin/pyproject.toml +++ b/lib/marin/pyproject.toml @@ -67,6 +67,8 @@ test = [ # need this for integration tests "pip", "openai-responses", + # for scaling law plotting tests + "plotly", ] lint = [ "ruff==0.14.3", diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 12ffbfaecd..086ec84196 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -18,7 +18,10 @@ IsoFlopAnalysisConfig, IsoFlopAnalysisResult, IsoFlopSweepConfig, + IsoFlopTrainArgs, candidate_configs, + compute_transformer_params, + generate_isoflop_train_args, isoflop_analysis_step, pick_v5p_type, predict_optimal_config, @@ -31,9 +34,12 @@ scaling_ladder_rung_step, scaling_ladder_suite, ) - -# Plotting functions are imported separately to avoid plotly dependency in core module -# from marin.scaling_laws.scaling_plots import create_isoflop_plot, create_scaling_plot, save_plots +from marin.scaling_laws.scaling_plots import ( + create_isoflop_plot, + create_scaling_plot, + save_plots, + upload_plots_to_wandb, +) __all__ = [ # Primary interface (ExecutorStep factories) @@ -42,18 +48,26 @@ "scaling_ladder_rung_step", # Programmatic interface "run_isoflop_analysis", + "generate_isoflop_train_args", # Dataclasses "CandidateConfig", "IsoFlopAnalysisConfig", "IsoFlopAnalysisResult", "IsoFlopSweepConfig", + "IsoFlopTrainArgs", "ScalingLadderSuite", "ScalingLadderRungConfig", # Constants "DEFAULT_BUDGETS", # Utilities "candidate_configs", + "compute_transformer_params", "pick_v5p_type", "predict_optimal_config", "predict_optimal_configs_for_budgets", + # Plotting functions + "create_isoflop_plot", + "create_scaling_plot", + "save_plots", + "upload_plots_to_wandb", ] diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index fa5f37e1f2..b0441743d9 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright 2025 Marin Authors -# SPDX-License-Identifier: Apache-2.0 -""" -Base infrastructure for eval metrics analysis. +"""Base infrastructure for eval metrics analysis. This module provides a base config and utilities for analysis jobs that read eval_metrics.jsonl files from completed training runs. 
Specific diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index fb6d30c200..d5e63430ed 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -37,7 +37,7 @@ import os import re from collections.abc import Iterator, Sequence -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass, field, replace import fsspec import jax.numpy as jnp @@ -45,6 +45,11 @@ from jaxopt import ScipyMinimize from levanter.utils.flop_utils import lm_flops_per_token +from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig +from levanter.models.qwen import Qwen3Config +from levanter.optim.cautious import CautiousConfig +from levanter.optim.config import OptimizerConfig + from marin.execution.executor import ExecutorStep, InputName, output_path_of, this_output_path from marin.scaling_laws.eval_metrics_reader import ( EvalMetricsAnalysisConfig, @@ -58,12 +63,25 @@ # ---------------- Constants ---------------- DEFAULT_METRIC_KEY = "eval/paloma/c4_en/bpb" SEQ_LEN = 4096 -CANON_LABELS = ["nemo", "comma", "dclm"] + +# Marin tokenizer vocab size (stanford-crfm/marin-tokenizer) +MARIN_TOKENIZER_VOCAB_SIZE = 128256 # ---------------- IsoFLOP Sweep Constants ---------------- -DEFAULT_BUDGETS = [1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20] +DEFAULT_BUDGETS: tuple[float, ...] = (1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20) MLP_RATIO = 4 +# Learning rate scaling: lr = LR_CONSTANT * sqrt(batch_size) / hidden_dim +LR_CONSTANT = 0.33 + +# Head size for attention: num_heads = hidden_dim / HIDDEN_HEAD_RATIO +HIDDEN_HEAD_RATIO = 128 + +# Beta2 scaling for Adam: beta2 = BETA2_BASE ** (batch_size / BETA2_BATCH_DIVISOR) +# Reference: https://arxiv.org/pdf/2507.07101 +BETA2_BASE = 0.98 +BETA2_BATCH_DIVISOR = 128 + # TPU v5p hardware constants for memory estimation HBM_PER_CHIP_GIB = 95 CORES_PER_CHIP = 2 @@ -71,7 +89,7 @@ # ---------------- IsoFLOP Sweep Config ---------------- -@dataclass +@dataclass(frozen=True) class IsoFlopSweepConfig: """Configuration for generating ISOFlop sweep candidate configs. @@ -82,10 +100,10 @@ class IsoFlopSweepConfig: tokenizer: str = "stanford-crfm/marin-tokenizer" """Tokenizer to use (needed for vocab size).""" - budgets: list[float] = field(default_factory=lambda: DEFAULT_BUDGETS.copy()) - """List of FLOP budgets to generate configs for.""" + budgets: tuple[float, ...] = DEFAULT_BUDGETS + """Tuple of FLOP budgets to generate configs for.""" - seq_len: int = 4096 + seq_len: int = SEQ_LEN """Sequence length for training.""" steps_per_run: int = 2**16 @@ -97,10 +115,10 @@ class IsoFlopSweepConfig: base_hidden_layer_ratio: int = 64 """Base ratio for hidden_dim to num_layers calculation.""" - hidden_head_ratio: int = 128 + hidden_head_ratio: int = HIDDEN_HEAD_RATIO """Ratio for hidden_dim to num_heads calculation.""" - lr_constant: float = 0.33 + lr_constant: float = LR_CONSTANT """Constant for learning rate calculation: lr = (lr_constant * sqrt(batch)) / hidden_dim.""" min_hidden_pow: int = 9 @@ -135,6 +153,37 @@ class CandidateConfig: flops_budget: float = 0.0 # the FLOP budget this config was generated for +@dataclass +class IsoFlopTrainArgs: + """Arguments needed to set up an isoflop training run. + + This dataclass contains all the information needed to call default_train() + for an isoflop sweep run. The caller is responsible for constructing the + experiment-specific SimpleTrainConfig from these arguments. 
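+
+    A rough sketch of how a caller might consume these fields (this mirrors the
+    usage in experiments/isoflop_sweep.py; ``base_train_config`` and
+    ``tokenized_dataset`` are placeholders supplied by the experiment):
+
+        train_cfg = replace(
+            base_train_config,
+            train_batch_size=args.candidate.batch_size,
+            learning_rate=args.candidate.learning_rate,
+            num_train_steps=args.candidate.train_steps,
+            resources=ResourceConfig.with_tpu(args.tpu_type),
+            optimizer_config=args.optimizer_config,
+        )
+        step = default_train(
+            name=args.run_name,
+            tokenized=tokenized_dataset,
+            model_config=args.model_config,
+            train_config=train_cfg,
+            eval_harness_tasks=[],
+            tags=args.tags,
+        )
+        step = step.with_output_path(args.output_path)  # pin checkpoints for reuse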
+ """ + + candidate: CandidateConfig + """The candidate configuration with model/training hyperparameters.""" + + model_config: Qwen3Config + """Levanter model configuration ready to use.""" + + optimizer_config: OptimizerConfig + """Levanter optimizer configuration with learning_rate and beta2 set.""" + + tpu_type: str + """TPU slice type (e.g., 'v5p-8', 'v5p-32').""" + + run_name: str + """Name for the training run.""" + + tags: tuple[str, ...] + """Tags for tracking/filtering runs.""" + + output_path: str + """Static output path for checkpoints.""" + + # ---------------- Candidate Config Generation ---------------- @@ -273,7 +322,7 @@ def candidate_configs( while lr > 0.01: batch_size //= 2 lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size - b2 = 0.98 ** (batch_size / 128) # https://arxiv.org/pdf/2507.07101 + b2 = BETA2_BASE ** (batch_size / BETA2_BATCH_DIVISOR) if batch_size < 8: continue @@ -323,6 +372,171 @@ def candidate_configs( ) +def compute_transformer_params( + hidden_dim: int, + intermediate_dim: int, + num_layers: int, + vocab_size: int, + num_kv_heads: int | None = None, + num_heads: int | None = None, +) -> int: + """Compute parameter count for a transformer model. + + This is a standard approximation for LLaMA-style models with: + - Embedding: vocab_size * hidden_dim + - Per layer: 4 * hidden_dim^2 (attention) + 3 * hidden_dim * intermediate_dim (MLP with GLU) + - Output: vocab_size * hidden_dim (shared with embedding in some models) + """ + # Embedding parameters + embed_params = vocab_size * hidden_dim + + # Attention parameters per layer: Q, K, V, O projections + # For GQA: Q = hidden * hidden, K = hidden * kv_dim, V = hidden * kv_dim, O = hidden * hidden + if num_kv_heads is not None and num_heads is not None: + head_dim = hidden_dim // num_heads + kv_dim = num_kv_heads * head_dim + attn_params_per_layer = ( + hidden_dim * hidden_dim # Q + + hidden_dim * kv_dim # K + + hidden_dim * kv_dim # V + + hidden_dim * hidden_dim # O + ) + else: + # Standard MHA: 4 * hidden^2 + attn_params_per_layer = 4 * hidden_dim * hidden_dim + + # MLP parameters per layer (GLU: gate, up, down) + mlp_params_per_layer = 3 * hidden_dim * intermediate_dim + + # Layer norm parameters (2 per layer + 1 final) + ln_params = (2 * num_layers + 1) * hidden_dim + + # Total + layer_params = (attn_params_per_layer + mlp_params_per_layer) * num_layers + total = embed_params + layer_params + ln_params + + return total + + +def generate_isoflop_train_args( + sweep_config: IsoFlopSweepConfig, + experiment_name: str, + vocab_size: int, + base_optimizer_config: OptimizerConfig | None = None, +) -> list[IsoFlopTrainArgs]: + """Generate training arguments for each candidate in an isoflop sweep. + + This function generates all the arguments needed to call default_train() for + each candidate configuration in the sweep. The caller is responsible for + constructing the experiment-specific SimpleTrainConfig. + + Args: + sweep_config: Configuration for the sweep (budgets, seq_len, etc.) + experiment_name: Name suffix for run names (e.g., 'nemo', 'dclm') + vocab_size: Vocabulary size for the tokenizer + base_optimizer_config: Base optimizer config to modify. If None, uses CautiousConfig defaults. + + Returns: + List of IsoFlopTrainArgs, one per candidate config across all budgets. 
+ + Example: + >>> from marin.scaling_laws import IsoFlopSweepConfig, generate_isoflop_train_args + >>> config = IsoFlopSweepConfig(budgets=(1e18, 1e19)) + >>> train_args = generate_isoflop_train_args(config, "my-experiment", vocab_size=128256) + >>> for args in train_args: + ... # Use args.model_config, args.optimizer_config, etc. with default_train() + ... pass + """ + if base_optimizer_config is None: + base_optimizer_config = CautiousConfig( + learning_rate=1.0, # Placeholder, will be overridden + weight_decay=0.1, + min_lr_ratio=0.0, + warmup=0.1, + beta1=0.95, + beta2=0.98, # Placeholder, will be overridden + epsilon=1e-15, + max_grad_norm=1, + adamc_weight_decay=True, + lr_schedule="linear", + decay=0.2, + ) + + results: list[IsoFlopTrainArgs] = [] + + for budget in sweep_config.budgets: + for candidate in candidate_configs(sweep_config, budget, vocab_size): + # Build model config + model_cfg = Qwen3Config( + max_seq_len=sweep_config.seq_len, + hidden_dim=candidate.hidden_size, + intermediate_dim=candidate.intermediate_dim, + num_heads=candidate.num_heads, + num_kv_heads=candidate.num_kv_heads, + num_layers=candidate.num_layers, + rope=Llama3RotaryEmbeddingsConfig(), + ) + + # Compute parameter count for TPU selection + param_count = compute_transformer_params( + hidden_dim=candidate.hidden_size, + intermediate_dim=candidate.intermediate_dim, + num_layers=candidate.num_layers, + vocab_size=vocab_size, + num_kv_heads=candidate.num_kv_heads, + num_heads=candidate.num_heads, + ) + + # Pick TPU type + tpu_type = pick_v5p_type( + param_count=param_count, + hidden=candidate.hidden_size, + layers=candidate.num_layers, + batch=candidate.batch_size, + seq_len=sweep_config.seq_len, + vocab=vocab_size, + ) + + # Build optimizer config with candidate-specific LR and beta2 + optimizer_cfg = replace( + base_optimizer_config, + learning_rate=candidate.learning_rate, + beta2=candidate.beta2, + ) + + # Generate run name and tags + run_name = ( + f"isoflop-{budget:.0e}-d{candidate.hidden_size}-" + f"L{candidate.num_layers}-B{candidate.batch_size}-{experiment_name}" + ) + + tags = ( + f"FLOPs={budget:.1e}", + f"d={candidate.hidden_size}", + f"L={candidate.num_layers}", + f"B={candidate.batch_size}", + f"steps={candidate.train_steps}", + f"tpu={tpu_type}", + ) + + # Static output path for checkpoint reuse + output_path = os.path.join("checkpoints", "isoflop", run_name) + + results.append( + IsoFlopTrainArgs( + candidate=candidate, + model_config=model_cfg, + optimizer_config=optimizer_cfg, + tpu_type=tpu_type, + run_name=run_name, + tags=tags, + output_path=output_path, + ) + ) + + return results + + # ---------------- Helpers ---------------- @@ -423,8 +637,7 @@ def fit_scaling_laws( if df is None or df.empty: return [], {}, {} - present = list(dict.fromkeys(df["label"].tolist())) - datasets = [lab for lab in CANON_LABELS if lab in present] + [lab for lab in present if lab not in CANON_LABELS] + datasets = list(dict.fromkeys(df["label"].tolist())) buckets = sorted(df.flops.unique()) @@ -555,10 +768,6 @@ def transform_metrics_for_isoflop( label = label_map[exp_name] else: label = exp_name - for canon in CANON_LABELS: - if canon in exp_name.lower(): - label = canon - break records.append( dict( @@ -582,7 +791,7 @@ def predict_optimal_config( target_flops: float, label: str, sweep_config: IsoFlopSweepConfig | None = None, - vocab_size: int = 128256, + vocab_size: int = MARIN_TOKENIZER_VOCAB_SIZE, ) -> CandidateConfig | None: """Predict optimal training config for a target compute budget using fitted 
scaling laws. @@ -596,7 +805,7 @@ def predict_optimal_config( target_flops: Target compute budget in FLOPs. label: Dataset/experiment label to use for scaling fit. sweep_config: Optional IsoFlopSweepConfig. If None, uses defaults. - vocab_size: Vocabulary size (default: 128256 for marin tokenizer). + vocab_size: Vocabulary size (default: MARIN_TOKENIZER_VOCAB_SIZE for marin tokenizer). Returns: CandidateConfig for the predicted optimal, or None if label not in fits @@ -638,7 +847,7 @@ def predict_optimal_configs_for_budgets( target_budgets: list[float], label: str, sweep_config: IsoFlopSweepConfig | None = None, - vocab_size: int = 128256, + vocab_size: int = MARIN_TOKENIZER_VOCAB_SIZE, ) -> list[CandidateConfig]: """Predict optimal configs for multiple target compute budgets. @@ -650,13 +859,21 @@ def predict_optimal_configs_for_budgets( vocab_size: Vocabulary size. Returns: - List of CandidateConfig for each budget (skips budgets with no valid config). + List of CandidateConfig for each budget. + + Raises: + RuntimeError: If any budget cannot be predicted (to prevent silent failures). """ configs = [] for budget in target_budgets: config = predict_optimal_config(scaling_fits, budget, label, sweep_config, vocab_size) - if config is not None: - configs.append(config) + if config is None: + raise RuntimeError( + f"Failed to predict optimal config for budget {budget:.2e} FLOPs " + f"with label '{label}'. Check that the label exists in scaling_fits " + f"and that the budget is within a valid range." + ) + configs.append(config) return configs @@ -759,12 +976,12 @@ def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: hidden_size=rec["hidden_dim"], intermediate_dim=rec["hidden_dim"] * MLP_RATIO, num_layers=rec["num_layers"], - num_heads=max(1, rec["hidden_dim"] // 128), - num_kv_heads=max(1, rec["hidden_dim"] // 128), + num_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), + num_kv_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), batch_size=rec["batch_size"], train_steps=int(rec["optimal_tokens"] / (rec["batch_size"] * SEQ_LEN)), - learning_rate=(0.33 * math.sqrt(rec["batch_size"])) / rec["hidden_dim"], - beta2=0.98 ** (rec["batch_size"] / 128), + learning_rate=(LR_CONSTANT * math.sqrt(rec["batch_size"])) / rec["hidden_dim"], + beta2=BETA2_BASE ** (rec["batch_size"] / BETA2_BATCH_DIVISOR), tokens=rec["optimal_tokens"], flops_budget=rec["flops"], ) @@ -852,7 +1069,7 @@ def isoflop_analysis_step( Example: >>> from marin.scaling_laws import isoflop_analysis_step - >>> analysis = scaling_ladder_step( + >>> analysis = isoflop_analysis_step( ... name="my-scaling-analysis", ... training_runs=my_training_steps, ... 
) @@ -952,12 +1169,12 @@ def run_isoflop_analysis( hidden_size=rec["hidden_dim"], intermediate_dim=rec["hidden_dim"] * MLP_RATIO, num_layers=rec["num_layers"], - num_heads=max(1, rec["hidden_dim"] // 128), - num_kv_heads=max(1, rec["hidden_dim"] // 128), + num_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), + num_kv_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), batch_size=rec["batch_size"], train_steps=int(rec["optimal_tokens"] / (rec["batch_size"] * SEQ_LEN)), - learning_rate=(0.33 * math.sqrt(rec["batch_size"])) / rec["hidden_dim"], - beta2=0.98 ** (rec["batch_size"] / 128), + learning_rate=(LR_CONSTANT * math.sqrt(rec["batch_size"])) / rec["hidden_dim"], + beta2=BETA2_BASE ** (rec["batch_size"] / BETA2_BATCH_DIVISOR), tokens=rec["optimal_tokens"], flops_budget=rec["flops"], ) diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 9a0a669427..bf3fa960a1 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -76,8 +76,8 @@ class ScalingLadderRungConfig: label: str """Dataset label to use for scaling fit (e.g., 'nemo', 'comma', 'dclm').""" - dataset: str - """Path to tokenized dataset for training.""" + tokenized: str | LMMixtureDatasetConfig + """Tokenized dataset for training. Can be a path or LMMixtureDatasetConfig.""" output_path: str """Where to write training outputs.""" @@ -91,6 +91,9 @@ class ScalingLadderRungConfig: sweep_config: IsoFlopSweepConfig | None = None """Optional sweep config for predict_optimal_config. Uses defaults if None.""" + use_default_validation: bool = True + """Whether to use the default validation sets (Paloma).""" + def load_scaling_fits(analysis_path: str) -> dict[str, tuple[float, float]]: """Load scaling fits from an IsoFLOP analysis output.""" @@ -149,7 +152,6 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: from experiments.defaults import _prepare_data_config from experiments.llama import compute_num_parameters from experiments.metrics.wandb_related import get_vocab_size_for_tokenizer - from marin.processing.tokenize import lm_mixture_data_config from marin.scaling_laws.isoflop_analysis import pick_v5p_type from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm @@ -202,11 +204,8 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: ) # Prepare data config (uses same helper as default_train) - data_config = lm_mixture_data_config( - components={"train": config.dataset}, - weights={"train": 1.0}, - ) - pretraining_data = _prepare_data_config(data_config, use_default_validation=True) + # Accepts both string paths and LMMixtureDatasetConfig + pretraining_data = _prepare_data_config(config.tokenized, use_default_validation=config.use_default_validation) # Build TrainLmConfig (mirrors default_train structure) train_config = TrainLmConfig( @@ -251,9 +250,11 @@ def scaling_ladder_rung_step( analysis_step: ExecutorStep, target_budget: float, label: str, - dataset: InputName | ExecutorStep | LMMixtureDatasetConfig, + tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, tokenizer: str = "stanford-crfm/marin-tokenizer", seq_len: int = 4096, + use_default_validation: bool = True, + override_output_path: str | None = None, ) -> ExecutorStep: """Create an ExecutorStep for one rung of the scaling ladder. 
@@ -265,34 +266,39 @@ def scaling_ladder_rung_step( analysis_step: The IsoFLOP analysis step to read fits from target_budget: Target compute budget in FLOPs label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') - dataset: Tokenized dataset to train on + tokenized: Tokenized dataset to train on (path, ExecutorStep, or LMMixtureDatasetConfig) tokenizer: Tokenizer to use seq_len: Sequence length for training + use_default_validation: Whether to use the default validation sets (Paloma) + override_output_path: Optional override for the output path Returns: ExecutorStep configured to run one optimal training run """ - # Resolve dataset path - if isinstance(dataset, ExecutorStep): - dataset_path = output_path_of(dataset) - elif isinstance(dataset, LMMixtureDatasetConfig): - # For mixture configs, we'll need to handle this differently - # For now, just serialize it somehow - this is a limitation - raise NotImplementedError("LMMixtureDatasetConfig not yet supported for scaling_ladder_rung_step") + # Resolve tokenized data - works like default_train + if isinstance(tokenized, ExecutorStep): + resolved_tokenized: str | LMMixtureDatasetConfig = output_path_of(tokenized) + elif isinstance(tokenized, LMMixtureDatasetConfig): + # Pass through LMMixtureDatasetConfig directly + resolved_tokenized = tokenized else: - dataset_path = dataset + # InputName or string path + resolved_tokenized = tokenized + + output_path = override_output_path if override_output_path is not None else this_output_path() config = ScalingLadderRungConfig( analysis_output_path=output_path_of(analysis_step), target_budget=target_budget, label=label, - dataset=dataset_path, - output_path=this_output_path(), + tokenized=resolved_tokenized, + output_path=output_path, tokenizer=tokenizer, seq_len=seq_len, + use_default_validation=use_default_validation, ) - return ExecutorStep( + step = ExecutorStep( name=os.path.join("checkpoints", name), fn=run_scaling_ladder_rung, config=config, @@ -300,6 +306,11 @@ def scaling_ladder_rung_step( pip_dependency_groups=["tokenize_train"], ) + if override_output_path is not None: + step = step.with_output_path(override_output_path) + + return step + # ---------------- Scaling Ladder Suite ---------------- @@ -329,7 +340,7 @@ def scaling_ladder_suite( training_runs: Sequence[ExecutorStep | InputName], target_budgets: Sequence[float], label: str, - dataset: InputName | ExecutorStep, + tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, tokenizer: str = "stanford-crfm/marin-tokenizer", seq_len: int = 4096, metric_key: str = "eval/paloma/c4_en/bpb", @@ -338,6 +349,7 @@ def scaling_ladder_suite( upload_to_wandb: bool = True, wandb_entity: str = "marin-community", wandb_project: str = "marin-analysis", + use_default_validation: bool = True, ) -> ScalingLadderSuite: """Create a complete scaling ladder: IsoFLOP analysis + optimal training runs. 
@@ -353,7 +365,7 @@ def scaling_ladder_suite( training_runs: IsoFLOP training run ExecutorSteps to analyze target_budgets: Target compute budgets (in FLOPs) for optimal training label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') - dataset: Tokenized dataset for optimal training runs + tokenized: Tokenized dataset for optimal training runs (path, ExecutorStep, or LMMixtureDatasetConfig) tokenizer: Tokenizer to use seq_len: Sequence length for training metric_key: Which metric to use for loss @@ -362,6 +374,7 @@ def scaling_ladder_suite( upload_to_wandb: Whether to upload plots to WandB wandb_entity: WandB entity for uploads wandb_project: WandB project for uploads + use_default_validation: Whether to use the default validation sets (Paloma) Returns: ScalingLadderSuite containing the analysis step and optimal training steps @@ -372,7 +385,7 @@ def scaling_ladder_suite( ... training_runs=isoflop_training_steps, ... target_budgets=[1e21, 3e21, 1e22], ... label="nemo", - ... dataset=nemotron_tokenized, + ... tokenized=nemotron_tokenized, ... ) >>> all_steps = [*isoflop_training_steps, *suite.all_steps] """ @@ -399,9 +412,10 @@ def scaling_ladder_suite( analysis_step=analysis, target_budget=budget, label=label, - dataset=dataset, + tokenized=tokenized, tokenizer=tokenizer, seq_len=seq_len, + use_default_validation=use_default_validation, ) optimal_runs.append(run_step) diff --git a/lib/marin/src/marin/scaling_laws/scaling_plots.py b/lib/marin/src/marin/scaling_laws/scaling_plots.py index bf49f312ae..a3a257388d 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_plots.py +++ b/lib/marin/src/marin/scaling_laws/scaling_plots.py @@ -83,13 +83,11 @@ _SCALE_MARKER = dict(symbol="circle", size=9, color=PALETTE[0]) _SCALE_LINE = dict(dash="dot", width=2, color=PALETTE[0]) -CANON_LABELS = ["nemo", "comma", "dclm"] - def create_isoflop_plot( df: pd.DataFrame, minima_records: list[dict], - fit_curves: dict[tuple[str, float], tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]], + fit_curves: dict[tuple[str, float], tuple[float, float, float]], ) -> go.Figure: """Create the IsoFLOP plot showing loss vs tokens for each compute budget. @@ -104,8 +102,7 @@ def create_isoflop_plot( if df.empty: return go.Figure() - present = list(dict.fromkeys(df["label"].tolist())) - datasets = [lab for lab in CANON_LABELS if lab in present] + [lab for lab in present if lab not in CANON_LABELS] + datasets = list(dict.fromkeys(df["label"].tolist())) buckets = sorted(df.flops.unique()) bucket_color = {C: PALETTE[i % len(PALETTE)] for i, C in enumerate(buckets)} @@ -212,8 +209,7 @@ def create_scaling_plot( for rec in minima_records: by_lab.setdefault(rec["label"], []).append(rec) - present = list(by_lab.keys()) - datasets = [lab for lab in CANON_LABELS if lab in present] + [lab for lab in present if lab not in CANON_LABELS] + datasets = list(by_lab.keys()) fig = go.Figure() diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py new file mode 100644 index 0000000000..373ad56815 --- /dev/null +++ b/tests/test_scaling_laws.py @@ -0,0 +1,498 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the scaling_laws module.""" + +import jax.numpy as jnp +import pandas as pd + +from marin.scaling_laws.isoflop_analysis import ( + BETA2_BASE, + BETA2_BATCH_DIVISOR, + DEFAULT_BUDGETS, + HIDDEN_HEAD_RATIO, + LR_CONSTANT, + MARIN_TOKENIZER_VOCAB_SIZE, + MLP_RATIO, + SEQ_LEN, + CandidateConfig, + IsoFlopSweepConfig, + IsoFlopTrainArgs, + candidate_configs, + compute_total_flops, + compute_transformer_params, + generate_isoflop_train_args, + parse_isoflop_run_name, + predict_optimal_config, + robust_quad_logx, + round_to_power_of_two, +) + + +# --- round_to_power_of_two tests --- + + +def test_round_to_power_of_two_exact_powers(): + """Test that exact powers of two are unchanged.""" + assert round_to_power_of_two(1) == 1 + assert round_to_power_of_two(2) == 2 + assert round_to_power_of_two(4) == 4 + assert round_to_power_of_two(8) == 8 + assert round_to_power_of_two(16) == 16 + + +def test_round_to_power_of_two_rounds_up(): + """Test that non-powers round up to nearest power of two.""" + assert round_to_power_of_two(3) == 4 + assert round_to_power_of_two(5) == 8 + assert round_to_power_of_two(7) == 8 + assert round_to_power_of_two(9) == 16 + + +def test_round_to_power_of_two_small_values(): + """Test that small/zero values become 1.""" + assert round_to_power_of_two(0.5) == 1 + assert round_to_power_of_two(0.1) == 1 + assert round_to_power_of_two(0) == 1 + + +def test_round_to_power_of_two_large_values(): + """Test rounding for large values.""" + assert round_to_power_of_two(100) == 128 + assert round_to_power_of_two(1000) == 1024 + assert round_to_power_of_two(1025) == 2048 + + +# --- compute_total_flops tests --- + + +def test_compute_total_flops_larger_model_uses_more_flops(): + """Test that larger models use more FLOPs.""" + small_flops = compute_total_flops( + batch=32, + num_layers=12, + hidden=512, + intermediate=2048, + num_kv_heads=8, + num_heads=8, + steps=1000, + seq_len=4096, + vocab_size=128256, + ) + large_flops = compute_total_flops( + batch=32, + num_layers=24, + hidden=1024, + intermediate=4096, + num_kv_heads=16, + num_heads=16, + steps=1000, + seq_len=4096, + vocab_size=128256, + ) + assert large_flops > small_flops + + +def test_compute_total_flops_linear_in_batch_and_steps(): + """Test that FLOPs scale linearly with batch size and steps.""" + base_flops = compute_total_flops( + batch=32, + num_layers=12, + hidden=512, + intermediate=2048, + num_kv_heads=8, + num_heads=8, + steps=1000, + seq_len=4096, + vocab_size=128256, + ) + double_batch_flops = compute_total_flops( + batch=64, + num_layers=12, + hidden=512, + intermediate=2048, + num_kv_heads=8, + num_heads=8, + steps=1000, + seq_len=4096, + vocab_size=128256, + ) + double_steps_flops = compute_total_flops( + batch=32, + num_layers=12, + hidden=512, + intermediate=2048, + num_kv_heads=8, + num_heads=8, + steps=2000, + seq_len=4096, + vocab_size=128256, + ) + assert abs(double_batch_flops - 2 * base_flops) / base_flops < 0.01 + assert abs(double_steps_flops - 2 * base_flops) / base_flops < 0.01 + + +# --- parse_isoflop_run_name tests --- + + +def test_parse_isoflop_run_name_basic(): + 
"""Test parsing a standard isoflop run name.""" + result = parse_isoflop_run_name("isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt") + assert result is not None + assert result["flops"] == 1e19 + assert result["d"] == 2048 + assert result["L"] == 16 + assert result["B"] == 1024 + assert result["experiment_name"] == "nemo-wider-depth-adapt" + + +def test_parse_isoflop_run_name_with_hash_suffix(): + """Test parsing run name with hash suffix.""" + result = parse_isoflop_run_name("isoflop-1e+18-d512-L8-B128-dclm-a1b2c3") + assert result is not None + assert result["flops"] == 1e18 + assert result["d"] == 512 + assert result["L"] == 8 + assert result["B"] == 128 + assert result["experiment_name"] == "dclm" + + +def test_parse_isoflop_run_name_invalid_format(): + """Test that invalid formats return None.""" + assert parse_isoflop_run_name("not-a-valid-name") is None + assert parse_isoflop_run_name("isoflop-missing-parts") is None + assert parse_isoflop_run_name("") is None + + +# --- candidate_configs tests --- + + +def test_candidate_configs_generates_candidates(): + """Test that candidate_configs generates at least one config.""" + cfg = IsoFlopSweepConfig() + candidates = list(candidate_configs(cfg, 1e19, MARIN_TOKENIZER_VOCAB_SIZE)) + assert len(candidates) > 0 + + +def test_candidate_configs_within_tolerance(): + """Test that generated configs are within FLOP tolerance.""" + cfg = IsoFlopSweepConfig(flop_tolerance=0.01) + budget = 1e19 + for candidate in candidate_configs(cfg, budget, MARIN_TOKENIZER_VOCAB_SIZE): + achieved = compute_total_flops( + candidate.batch_size, + candidate.num_layers, + candidate.hidden_size, + candidate.intermediate_dim, + candidate.num_kv_heads, + candidate.num_heads, + candidate.train_steps, + cfg.seq_len, + MARIN_TOKENIZER_VOCAB_SIZE, + ) + relative_error = abs(achieved - budget) / budget + assert relative_error <= cfg.flop_tolerance + + +def test_candidate_configs_fields_populated(): + """Test that all candidate fields are properly populated.""" + cfg = IsoFlopSweepConfig() + candidates = list(candidate_configs(cfg, 1e19, MARIN_TOKENIZER_VOCAB_SIZE)) + assert len(candidates) > 0 + + for candidate in candidates: + assert candidate.hidden_size > 0 + assert candidate.intermediate_dim == candidate.hidden_size * MLP_RATIO + assert candidate.num_layers > 0 + assert candidate.num_heads > 0 + assert candidate.num_kv_heads > 0 + assert candidate.batch_size >= 8 + assert candidate.train_steps > 0 + assert candidate.learning_rate > 0 + assert 0 < candidate.beta2 < 1 + assert candidate.tokens > 0 + assert candidate.flops_budget == 1e19 + + +# --- robust_quad_logx tests --- + + +def test_robust_quad_logx_fits_quadratic(): + """Test that robust_quad_logx recovers known coefficients.""" + x = jnp.array([1e9, 1e10, 1e11, 1e12]) + L = jnp.log10(x) + y = 0.1 * L**2 - 2 * L + 20 + + a, b, c = robust_quad_logx(x, y) + + assert abs(a - 0.1) < 0.01 + assert abs(b - (-2)) < 0.1 + assert abs(c - 20) < 0.5 + + +def test_robust_quad_logx_handles_noise(): + """Test that robust_quad_logx handles noisy data.""" + x = jnp.array([1e9, 1e10, 1e11, 1e12, 1e13]) + L = jnp.log10(x) + y_clean = 0.05 * L**2 - 1.5 * L + 15 + noise = jnp.array([0.01, -0.02, 0.015, -0.01, 0.005]) + y = y_clean + noise + + a, b, c = robust_quad_logx(x, y) + + assert abs(a - 0.05) < 0.05 + assert abs(b - (-1.5)) < 0.5 + + +# --- predict_optimal_config tests --- + + +def test_predict_optimal_config_unknown_label_returns_none(): + """Test that unknown labels return None.""" + scaling_fits = {"nemo": (0.5, 1e5)} + 
result = predict_optimal_config( + scaling_fits=scaling_fits, + target_flops=1e21, + label="unknown", + ) + assert result is None + + +def test_predict_optimal_config_valid_label(): + """Test prediction with a valid label.""" + scaling_fits = {"nemo": (0.5, 1e5)} + result = predict_optimal_config( + scaling_fits=scaling_fits, + target_flops=1e20, + label="nemo", + ) + assert result is None or isinstance(result, CandidateConfig) + + +# --- Constants tests --- + + +def test_constants_default_budgets(): + """Test that DEFAULT_BUDGETS is valid.""" + assert len(DEFAULT_BUDGETS) > 0 + assert all(b > 0 for b in DEFAULT_BUDGETS) + assert list(DEFAULT_BUDGETS) == sorted(DEFAULT_BUDGETS) + + +def test_constants_have_expected_values(): + """Test that constants have expected values.""" + assert SEQ_LEN == 4096 + assert MARIN_TOKENIZER_VOCAB_SIZE == 128256 + assert LR_CONSTANT == 0.33 + assert HIDDEN_HEAD_RATIO == 128 + assert BETA2_BASE == 0.98 + assert BETA2_BATCH_DIVISOR == 128 + assert MLP_RATIO == 4 + + +# --- compute_transformer_params tests --- + + +def test_compute_transformer_params_returns_positive_int(): + """Test that compute_transformer_params returns a positive integer.""" + params = compute_transformer_params( + hidden_dim=512, + intermediate_dim=2048, + num_layers=12, + vocab_size=128256, + num_kv_heads=8, + num_heads=8, + ) + assert params > 0 + assert isinstance(params, int) + + +def test_compute_transformer_params_scales_with_hidden_dim(): + """Test that params scale with hidden dimension.""" + small = compute_transformer_params( + hidden_dim=512, + intermediate_dim=2048, + num_layers=12, + vocab_size=128256, + ) + large = compute_transformer_params( + hidden_dim=1024, + intermediate_dim=4096, + num_layers=12, + vocab_size=128256, + ) + assert large > small + + +def test_compute_transformer_params_scales_with_layers(): + """Test that params scale with number of layers.""" + shallow = compute_transformer_params( + hidden_dim=512, + intermediate_dim=2048, + num_layers=6, + vocab_size=128256, + ) + deep = compute_transformer_params( + hidden_dim=512, + intermediate_dim=2048, + num_layers=12, + vocab_size=128256, + ) + assert deep > shallow + + +# --- generate_isoflop_train_args tests --- + + +def test_generate_isoflop_train_args_returns_list(): + """Test that generate_isoflop_train_args returns a list of IsoFlopTrainArgs.""" + config = IsoFlopSweepConfig(budgets=(1e18,)) + result = generate_isoflop_train_args( + sweep_config=config, + experiment_name="test-experiment", + vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, + ) + assert isinstance(result, list) + assert len(result) > 0 + assert all(isinstance(arg, IsoFlopTrainArgs) for arg in result) + + +def test_generate_isoflop_train_args_populates_fields(): + """Test that all required fields are populated.""" + config = IsoFlopSweepConfig(budgets=(1e18,)) + result = generate_isoflop_train_args( + sweep_config=config, + experiment_name="test-experiment", + vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, + ) + for args in result: + assert args.candidate is not None + assert args.candidate.hidden_size > 0 + assert args.candidate.num_layers > 0 + + assert args.model_config is not None + assert args.model_config.hidden_dim == args.candidate.hidden_size + assert args.model_config.num_layers == args.candidate.num_layers + + assert args.optimizer_config is not None + assert args.optimizer_config.learning_rate == args.candidate.learning_rate + + assert args.tpu_type.startswith("v5p-") + assert "isoflop" in args.run_name + assert "test-experiment" in args.run_name + 
assert len(args.tags) > 0 + assert args.output_path.startswith("checkpoints/isoflop/") + + +def test_generate_isoflop_train_args_more_budgets_more_configs(): + """Test that more budgets produce more configs.""" + config_single = IsoFlopSweepConfig(budgets=(1e18,)) + config_multi = IsoFlopSweepConfig(budgets=(1e18, 1e19)) + + result_single = generate_isoflop_train_args( + sweep_config=config_single, + experiment_name="test", + vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, + ) + result_multi = generate_isoflop_train_args( + sweep_config=config_multi, + experiment_name="test", + vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, + ) + assert len(result_multi) > len(result_single) + + +def test_generate_isoflop_train_args_unique_run_names(): + """Test that all run names are unique.""" + config = IsoFlopSweepConfig(budgets=(1e18, 1e19)) + result = generate_isoflop_train_args( + sweep_config=config, + experiment_name="test", + vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, + ) + run_names = [args.run_name for args in result] + assert len(run_names) == len(set(run_names)) + + +def test_generate_isoflop_train_args_includes_experiment_name(): + """Test that experiment name appears in run names.""" + config = IsoFlopSweepConfig(budgets=(1e18,)) + result = generate_isoflop_train_args( + sweep_config=config, + experiment_name="my-custom-experiment", + vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, + ) + for args in result: + assert "my-custom-experiment" in args.run_name + + +# --- Plotting tests --- + + +def test_create_isoflop_plot_empty_data(): + """Test that create_isoflop_plot handles empty data.""" + from marin.scaling_laws import create_isoflop_plot + + df = pd.DataFrame() + fig = create_isoflop_plot(df, [], {}) + assert fig is not None + + +def test_create_isoflop_plot_with_data(): + """Test create_isoflop_plot with sample data.""" + from marin.scaling_laws import create_isoflop_plot + + df = pd.DataFrame( + { + "tokens": [1e9, 2e9, 3e9], + "loss": [2.5, 2.3, 2.2], + "flops": [1e18, 1e18, 1e18], + "params": [1e8, 1e8, 1e8], + "name": ["run1", "run2", "run3"], + "label": ["nemo", "nemo", "nemo"], + } + ) + minima_records = [ + { + "label": "nemo", + "flops": 1e18, + "optimal_tokens": 2e9, + "loss_at_optimal": 2.3, + "optimal_params": 1e8, + } + ] + fit_curves = {("nemo", 1e18): (0.1, -1.0, 3.0)} + fig = create_isoflop_plot(df, minima_records, fit_curves) + assert fig is not None + + +def test_create_scaling_plot_empty(): + """Test that create_scaling_plot handles empty data.""" + from marin.scaling_laws import create_scaling_plot + + fig = create_scaling_plot([], {}) + assert fig is not None + + +def test_create_scaling_plot_with_data(): + """Test create_scaling_plot with sample data.""" + from marin.scaling_laws import create_scaling_plot + + minima_records = [ + {"label": "nemo", "flops": 1e18, "optimal_tokens": 1e9}, + {"label": "nemo", "flops": 1e19, "optimal_tokens": 5e9}, + ] + scaling_fits = {"nemo": (0.5, 1e5)} + fig = create_scaling_plot(minima_records, scaling_fits) + assert fig is not None diff --git a/uv.lock b/uv.lock index aed186490a..236561e869 100644 --- a/uv.lock +++ b/uv.lock @@ -4418,6 +4418,7 @@ dev = [ { name = "mypy" }, { name = "openai-responses" }, { name = "pip" }, + { name = "plotly" }, { name = "pylatexenc" }, { name = "pymdown-extensions" }, { name = "pyrefly" }, @@ -4466,6 +4467,7 @@ metrics = [ test = [ { name = "openai-responses" }, { name = "pip" }, + { name = "plotly" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, @@ -4554,6 +4556,7 @@ dev = [ { name = "mypy", 
specifier = ">=1.4.1" }, { name = "openai-responses" }, { name = "pip" }, + { name = "plotly" }, { name = "pylatexenc" }, { name = "pymdown-extensions", specifier = ">=10.0.0" }, { name = "pyrefly", specifier = "==0.40.0" }, @@ -4600,6 +4603,7 @@ metrics = [{ name = "google-cloud-logging" }] test = [ { name = "openai-responses" }, { name = "pip" }, + { name = "plotly" }, { name = "pytest", specifier = ">=8.3.2" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, From 84b14aa8544a954ef16ca9589595d5c1950ff147 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Sun, 21 Dec 2025 09:19:42 -0800 Subject: [PATCH 09/79] Remove Experiment Deps --- experiments/isoflop_sweep.py | 4 +- experiments/metrics/wandb_related.py | 21 +------ .../src/marin/processing/tokenize/__init__.py | 1 + .../marin/processing/tokenize/data_configs.py | 26 ++++++++ .../marin/scaling_laws/isoflop_analysis.py | 62 ++++++++++++------- .../src/marin/scaling_laws/scaling_ladder.py | 48 +++++++------- 6 files changed, 93 insertions(+), 69 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 650f9ca286..bc825dfa15 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -31,17 +31,15 @@ from experiments.common_pile.tokenize_common_pile import comma_main_mixture from experiments.defaults import default_tokenize, default_train from experiments.llama import llama3_tokenizer -from experiments.metrics.wandb_related import get_vocab_size_for_tokenizer from experiments.pretraining_datasets.simple import downloads from experiments.simple_train_config import SimpleTrainConfig from experiments.tootsie.exp1295_32b import nemotron_mix from fray.cluster import ResourceConfig from marin.execution.executor import ExecutorStep, InputName, executor_main -from marin.processing.tokenize import lm_mixture_data_config +from marin.processing.tokenize import get_vocab_size_for_tokenizer, lm_mixture_data_config from marin.scaling_laws.isoflop_analysis import ( CandidateConfig, IsoFlopSweepConfig, - IsoFlopTrainArgs, generate_isoflop_train_args, ) diff --git a/experiments/metrics/wandb_related.py b/experiments/metrics/wandb_related.py index b74d902fe4..2c478bbf15 100644 --- a/experiments/metrics/wandb_related.py +++ b/experiments/metrics/wandb_related.py @@ -18,6 +18,7 @@ from typing import Any import wandb +from marin.processing.tokenize import get_vocab_size_for_tokenizer from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT logger = logging.getLogger(__name__) @@ -145,26 +146,6 @@ def get_all_runs_over_period( return None -def get_vocab_size_for_tokenizer(tokenizer: str) -> int | None: - logger.info(f"Tokenizer:{tokenizer}") - if tokenizer == "EleutherAI/gpt-neox-20b": - vocab_size = 50_257 - elif tokenizer == "meta-llama/Meta-Llama-3.1-8B": - vocab_size = 128_256 - elif tokenizer == "stanford-crfm/marin-tokenizer": - vocab_size = 128_256 - elif tokenizer == "meta-llama/Llama-2-7b": - vocab_size = 32_000 - elif tokenizer == "gpt2": - vocab_size = 50_257 - else: - logger.error(f"Unknown tokenizer: {tokenizer}") - return None - - logger.info(f"Vocab size: {vocab_size}") - return vocab_size - - def count_params_for_run(run_id: str, entity=WANDB_ENTITY, project=WANDB_PROJECT) -> int | None: """ Retrieves the number of parameters for a specific WandB run. 
diff --git a/lib/marin/src/marin/processing/tokenize/__init__.py b/lib/marin/src/marin/processing/tokenize/__init__.py index 3413b817f1..0706bc9d49 100644 --- a/lib/marin/src/marin/processing/tokenize/__init__.py +++ b/lib/marin/src/marin/processing/tokenize/__init__.py @@ -15,6 +15,7 @@ from .data_configs import ( TokenizerStep, add_validation_sets_to_mixture, + get_vocab_size_for_tokenizer, lm_data_config, lm_mixture_data_config, mixture_for_evaluation, diff --git a/lib/marin/src/marin/processing/tokenize/data_configs.py b/lib/marin/src/marin/processing/tokenize/data_configs.py index 4076f9c01b..23c0f78523 100644 --- a/lib/marin/src/marin/processing/tokenize/data_configs.py +++ b/lib/marin/src/marin/processing/tokenize/data_configs.py @@ -34,6 +34,14 @@ logger = logging.getLogger(__name__) +_KNOWN_VOCAB_SIZES: dict[str, int] = { + "EleutherAI/gpt-neox-20b": 50_257, + "meta-llama/Meta-Llama-3.1-8B": 128_256, + "stanford-crfm/marin-tokenizer": 128_256, + "meta-llama/Llama-2-7b": 32_000, + "gpt2": 50_257, +} + def step_to_lm_mixture_component(step: TokenizerStep | TokenizeConfig, include_raw_paths: bool) -> LMDatasetSourceConfig: """ @@ -333,6 +341,24 @@ def _load_tokenizer(tokenizer_name: str) -> transformers.PreTrainedTokenizer: return load_tokenizer_with_backoff(tokenizer_name) +@lru_cache(maxsize=128) +def get_vocab_size_for_tokenizer(tokenizer_name: str) -> int: + """Return the vocabulary size for a tokenizer name. + + Args: + tokenizer_name: HuggingFace tokenizer name or path. + + Returns: + Vocabulary size for the tokenizer. + """ + resolved_name = unwrap_versioned_value(tokenizer_name) + if resolved_name in _KNOWN_VOCAB_SIZES: + return _KNOWN_VOCAB_SIZES[resolved_name] + + tokenizer = _load_tokenizer(resolved_name) + return len(tokenizer) + + def _are_tokenizers_equivalent(tokenizer1: str, tokenizer2: str) -> bool: """Compare two tokenizers by loading them and comparing their vocabularies and token IDs""" tokenizer1 = unwrap_versioned_value(tokenizer1) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index d5e63430ed..a7f9250896 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -37,7 +37,8 @@ import os import re from collections.abc import Iterator, Sequence -from dataclasses import asdict, dataclass, field, replace +from dataclasses import asdict, dataclass, replace +from typing import NotRequired, TypedDict import fsspec import jax.numpy as jnp @@ -57,7 +58,6 @@ read_metrics_dataframe, ) - logger = logging.getLogger(__name__) # ---------------- Constants ---------------- @@ -184,6 +184,29 @@ class IsoFlopTrainArgs: """Static output path for checkpoints.""" +# ---------------- Typed Records ---------------- + + +class _NearestConfig(TypedDict): + hidden_dim: int + num_layers: int + batch_size: int + params: float + + +class MinimaRecord(TypedDict): + label: str + flops: float + optimal_tokens: float + loss_at_optimal: float + hidden_dim: int + num_layers: int + batch_size: int + optimal_params: float + scaling_alpha: NotRequired[float] + scaling_A: NotRequired[float] + + # ---------------- Candidate Config Generation ---------------- @@ -478,14 +501,7 @@ def generate_isoflop_train_args( ) # Compute parameter count for TPU selection - param_count = compute_transformer_params( - hidden_dim=candidate.hidden_size, - intermediate_dim=candidate.intermediate_dim, - num_layers=candidate.num_layers, - vocab_size=vocab_size, - 
num_kv_heads=candidate.num_kv_heads, - num_heads=candidate.num_heads, - ) + param_count = model_cfg.total_trainable_params(vocab_size) # Pick TPU type tpu_type = pick_v5p_type( @@ -598,7 +614,7 @@ def _compute_optimal_params(flops: float, tokens: float) -> float: return flops / (6 * tokens) -def _find_nearest_config(df: pd.DataFrame, flops: float, tokens: float) -> dict: +def _find_nearest_config(df: pd.DataFrame, flops: float, tokens: float) -> _NearestConfig: """Find the nearest actual config from the dataframe to use as template.""" sub = df[df.flops == flops] if sub.empty: @@ -610,10 +626,10 @@ def _find_nearest_config(df: pd.DataFrame, flops: float, tokens: float) -> dict: meta = parse_isoflop_run_name(run_name) return { - "hidden_dim": meta["d"] if meta else 0, - "num_layers": meta["L"] if meta else 0, - "batch_size": meta["B"] if meta else 0, - "params": row.get("params", _compute_optimal_params(flops, tokens)), + "hidden_dim": int(meta["d"]) if meta else 0, + "num_layers": int(meta["L"]) if meta else 0, + "batch_size": int(meta["B"]) if meta else 0, + "params": float(row.get("params", _compute_optimal_params(flops, tokens))), } @@ -622,7 +638,11 @@ def _find_nearest_config(df: pd.DataFrame, flops: float, tokens: float) -> dict: def fit_scaling_laws( df: pd.DataFrame, -) -> tuple[list[dict], dict[str, tuple[float, float]], dict[tuple[str, float], tuple[float, float, float]]]: +) -> tuple[ + list[MinimaRecord], + dict[str, tuple[float, float]], + dict[tuple[str, float], tuple[float, float, float]], +]: """ Fit scaling laws and extract optimal configurations. @@ -641,8 +661,8 @@ def fit_scaling_laws( buckets = sorted(df.flops.unique()) - minima_records = [] - fit_curves = {} + minima_records: list[MinimaRecord] = [] + fit_curves: dict[tuple[str, float], tuple[float, float, float]] = {} # Fit quadratic for each (label, budget) and find minima for lab in datasets: @@ -680,8 +700,8 @@ def fit_scaling_laws( ) # Fit scaling law N* ~ A * C^alpha per dataset - scaling_fits = {} - by_lab = {} + scaling_fits: dict[str, tuple[float, float]] = {} + by_lab: dict[str, list[MinimaRecord]] = {} for rec in minima_records: by_lab.setdefault(rec["label"], []).append(rec) @@ -893,7 +913,7 @@ class IsoFlopAnalysisResult: isoflop_df: pd.DataFrame """Transformed dataframe used for analysis.""" - minima_records: list[dict] + minima_records: list[MinimaRecord] """Raw minima records with detailed info for each optimum.""" fit_curves: dict[tuple[str, float], tuple[float, float, float]] diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index bf3fa960a1..b44a6046d6 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -42,18 +42,30 @@ import os from collections.abc import Sequence from dataclasses import dataclass +from datetime import timedelta import fsspec +import jmp +from experiments.defaults import _prepare_data_config +from fray.cluster import ResourceConfig +from levanter.checkpoint import CheckpointerConfig from levanter.data.text import LMMixtureDatasetConfig from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig +from levanter.main.train_lm import TrainLmConfig from levanter.models.qwen import Qwen3Config +from levanter.optim.cautious import CautiousConfig +from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerConfig from marin.execution.executor import ExecutorStep, InputName, output_path_of, this_output_path +from 
marin.processing.tokenize import get_vocab_size_for_tokenizer from marin.scaling_laws.isoflop_analysis import ( CandidateConfig, IsoFlopSweepConfig, + pick_v5p_type, predict_optimal_config, ) +from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm logger = logging.getLogger(__name__) @@ -76,8 +88,8 @@ class ScalingLadderRungConfig: label: str """Dataset label to use for scaling fit (e.g., 'nemo', 'comma', 'dclm').""" - tokenized: str | LMMixtureDatasetConfig - """Tokenized dataset for training. Can be a path or LMMixtureDatasetConfig.""" + tokenized: InputName | str | LMMixtureDatasetConfig + """Tokenized dataset for training. Can be a path, InputName, or LMMixtureDatasetConfig.""" output_path: str """Where to write training outputs.""" @@ -104,13 +116,16 @@ def load_scaling_fits(analysis_path: str) -> dict[str, tuple[float, float]]: result = json.load(f) # Convert lists back to tuples - return {k: tuple(v) for k, v in result["scaling_fits"].items()} + scaling_fits: dict[str, tuple[float, float]] = {} + for key, value in result["scaling_fits"].items(): + if len(value) != 2: + raise ValueError(f"Expected 2 scaling fit values for '{key}', got {len(value)}") + scaling_fits[key] = (float(value[0]), float(value[1])) + return scaling_fits def get_optimal_candidate(config: ScalingLadderRungConfig) -> CandidateConfig: """Load scaling fits and predict optimal config for target budget.""" - from experiments.metrics.wandb_related import get_vocab_size_for_tokenizer - scaling_fits = load_scaling_fits(config.analysis_output_path) vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) @@ -124,8 +139,7 @@ def get_optimal_candidate(config: ScalingLadderRungConfig) -> CandidateConfig: if candidate is None: raise RuntimeError( - f"Could not find optimal config for budget {config.target_budget:.2e} " - f"and label '{config.label}'" + f"Could not find optimal config for budget {config.target_budget:.2e} " f"and label '{config.label}'" ) return candidate @@ -139,22 +153,6 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: 2. Predicts the optimal config for the target budget 3. 
Trains a model with that config using the same infrastructure as default_train """ - from datetime import timedelta - - import jmp - from fray.cluster import ResourceConfig - from levanter.checkpoint import CheckpointerConfig - from levanter.main.train_lm import TrainLmConfig - from levanter.optim.cautious import CautiousConfig - from levanter.tracker.wandb import WandbConfig - from levanter.trainer import TrainerConfig - - from experiments.defaults import _prepare_data_config - from experiments.llama import compute_num_parameters - from experiments.metrics.wandb_related import get_vocab_size_for_tokenizer - from marin.scaling_laws.isoflop_analysis import pick_v5p_type - from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm - # Get the optimal candidate config from analysis candidate = get_optimal_candidate(config) vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) @@ -178,7 +176,7 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: ) # Pick TPU type based on memory requirements - param_count = compute_num_parameters(model_cfg, vocab_size) + param_count = model_cfg.total_trainable_params(vocab_size) tpu_type = pick_v5p_type( param_count, candidate.hidden_size, @@ -277,7 +275,7 @@ def scaling_ladder_rung_step( """ # Resolve tokenized data - works like default_train if isinstance(tokenized, ExecutorStep): - resolved_tokenized: str | LMMixtureDatasetConfig = output_path_of(tokenized) + resolved_tokenized: InputName | str | LMMixtureDatasetConfig = output_path_of(tokenized) elif isinstance(tokenized, LMMixtureDatasetConfig): # Pass through LMMixtureDatasetConfig directly resolved_tokenized = tokenized From 6ffaee89a6cf90f495b660f9d98229af98e9b198 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Sun, 21 Dec 2025 09:52:46 -0800 Subject: [PATCH 10/79] LoB --- .../src/marin/scaling_laws/scaling_ladder.py | 68 +++++++++++++++---- 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index b44a6046d6..e1dcd92155 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -46,7 +46,6 @@ import fsspec import jmp -from experiments.defaults import _prepare_data_config from fray.cluster import ResourceConfig from levanter.checkpoint import CheckpointerConfig from levanter.data.text import LMMixtureDatasetConfig @@ -59,6 +58,8 @@ from marin.execution.executor import ExecutorStep, InputName, output_path_of, this_output_path from marin.processing.tokenize import get_vocab_size_for_tokenizer +from marin.processing.tokenize.data_configs import add_validation_sets_to_mixture, lm_data_config +from marin.processing.tokenize.tokenize import TokenizeConfig from marin.scaling_laws.isoflop_analysis import ( CandidateConfig, IsoFlopSweepConfig, @@ -69,6 +70,42 @@ logger = logging.getLogger(__name__) +# Type alias for tokenizer steps +TokenizerStep = ExecutorStep[TokenizeConfig] + + +def _prepare_data_config( + tokenized: InputName | str | LMMixtureDatasetConfig, + validation_sets: dict[str, TokenizerStep] | None = None, +) -> LMMixtureDatasetConfig: + """Prepare a tokenized dataset for training. + + This is a local helper that prepares data configs without depending on + experiment-specific validation sets. Callers should pass validation sets + explicitly if needed. 
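+    Unlike experiments.defaults._prepare_data_config, no default (Paloma) validation sets are attached here.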
+ + Args: + tokenized: The tokenized dataset - can be an InputName, path string, + or an already-configured LMMixtureDatasetConfig. + validation_sets: Optional dict of validation sets to add. If None, + no validation sets are added. + + Returns: + LMMixtureDatasetConfig ready for training. + """ + if isinstance(tokenized, LMMixtureDatasetConfig): + pretraining_data = tokenized + if validation_sets: + pretraining_data = add_validation_sets_to_mixture(pretraining_data, validation_sets) + else: + # InputName or string path + pretraining_data = lm_data_config( + training_set=tokenized, + validation_sets=validation_sets, + permutation_type="feistel", + ) + return pretraining_data + @dataclass(frozen=True) class ScalingLadderRungConfig: @@ -77,6 +114,10 @@ class ScalingLadderRungConfig: This config references an IsoFLOP analysis step and specifies the target compute budget. At runtime, the optimal config is loaded from the analysis output. + + Note: If you need validation sets, pass an LMMixtureDatasetConfig with + validation sets already configured. This module does not handle default + validation sets to avoid experiment-specific dependencies. """ analysis_output_path: str @@ -89,7 +130,8 @@ class ScalingLadderRungConfig: """Dataset label to use for scaling fit (e.g., 'nemo', 'comma', 'dclm').""" tokenized: InputName | str | LMMixtureDatasetConfig - """Tokenized dataset for training. Can be a path, InputName, or LMMixtureDatasetConfig.""" + """Tokenized dataset for training. Can be a path, InputName, or LMMixtureDatasetConfig. + If validation sets are needed, pass an LMMixtureDatasetConfig with them pre-configured.""" output_path: str """Where to write training outputs.""" @@ -103,9 +145,6 @@ class ScalingLadderRungConfig: sweep_config: IsoFlopSweepConfig | None = None """Optional sweep config for predict_optimal_config. Uses defaults if None.""" - use_default_validation: bool = True - """Whether to use the default validation sets (Paloma).""" - def load_scaling_fits(analysis_path: str) -> dict[str, tuple[float, float]]: """Load scaling fits from an IsoFLOP analysis output.""" @@ -201,9 +240,10 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: decay=0.2, ) - # Prepare data config (uses same helper as default_train) + # Prepare data config # Accepts both string paths and LMMixtureDatasetConfig - pretraining_data = _prepare_data_config(config.tokenized, use_default_validation=config.use_default_validation) + # If validation sets are needed, they should be pre-configured in the LMMixtureDatasetConfig + pretraining_data = _prepare_data_config(config.tokenized) # Build TrainLmConfig (mirrors default_train structure) train_config = TrainLmConfig( @@ -251,7 +291,6 @@ def scaling_ladder_rung_step( tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, tokenizer: str = "stanford-crfm/marin-tokenizer", seq_len: int = 4096, - use_default_validation: bool = True, override_output_path: str | None = None, ) -> ExecutorStep: """Create an ExecutorStep for one rung of the scaling ladder. @@ -264,10 +303,11 @@ def scaling_ladder_rung_step( analysis_step: The IsoFLOP analysis step to read fits from target_budget: Target compute budget in FLOPs label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') - tokenized: Tokenized dataset to train on (path, ExecutorStep, or LMMixtureDatasetConfig) + tokenized: Tokenized dataset to train on. Can be an ExecutorStep, InputName, + or LMMixtureDatasetConfig. 
If validation sets are needed, pass an + LMMixtureDatasetConfig with them pre-configured. tokenizer: Tokenizer to use seq_len: Sequence length for training - use_default_validation: Whether to use the default validation sets (Paloma) override_output_path: Optional override for the output path Returns: @@ -293,7 +333,6 @@ def scaling_ladder_rung_step( output_path=output_path, tokenizer=tokenizer, seq_len=seq_len, - use_default_validation=use_default_validation, ) step = ExecutorStep( @@ -347,7 +386,6 @@ def scaling_ladder_suite( upload_to_wandb: bool = True, wandb_entity: str = "marin-community", wandb_project: str = "marin-analysis", - use_default_validation: bool = True, ) -> ScalingLadderSuite: """Create a complete scaling ladder: IsoFLOP analysis + optimal training runs. @@ -363,7 +401,9 @@ def scaling_ladder_suite( training_runs: IsoFLOP training run ExecutorSteps to analyze target_budgets: Target compute budgets (in FLOPs) for optimal training label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') - tokenized: Tokenized dataset for optimal training runs (path, ExecutorStep, or LMMixtureDatasetConfig) + tokenized: Tokenized dataset for optimal training runs. Can be an ExecutorStep, + InputName, or LMMixtureDatasetConfig. If validation sets are needed, + pass an LMMixtureDatasetConfig with them pre-configured. tokenizer: Tokenizer to use seq_len: Sequence length for training metric_key: Which metric to use for loss @@ -372,7 +412,6 @@ def scaling_ladder_suite( upload_to_wandb: Whether to upload plots to WandB wandb_entity: WandB entity for uploads wandb_project: WandB project for uploads - use_default_validation: Whether to use the default validation sets (Paloma) Returns: ScalingLadderSuite containing the analysis step and optimal training steps @@ -413,7 +452,6 @@ def scaling_ladder_suite( tokenized=tokenized, tokenizer=tokenizer, seq_len=seq_len, - use_default_validation=use_default_validation, ) optimal_runs.append(run_step) From fd68b3e28af1743ae9da4b97a04517558df42dc0 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Sun, 21 Dec 2025 10:18:11 -0800 Subject: [PATCH 11/79] More Refactor --- lib/marin/src/marin/scaling_laws/__init__.py | 5 + .../marin/scaling_laws/eval_metrics_reader.py | 47 +------ .../marin/scaling_laws/isoflop_analysis.py | 126 +++++++++++------- .../src/marin/scaling_laws/scaling_ladder.py | 33 +---- 4 files changed, 91 insertions(+), 120 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 086ec84196..c1bb789643 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -19,6 +19,8 @@ IsoFlopAnalysisResult, IsoFlopSweepConfig, IsoFlopTrainArgs, + build_model_config, + build_optimizer_config, candidate_configs, compute_transformer_params, generate_isoflop_train_args, @@ -59,6 +61,9 @@ "ScalingLadderRungConfig", # Constants "DEFAULT_BUDGETS", + # Shared builders + "build_model_config", + "build_optimizer_config", # Utilities "candidate_configs", "compute_transformer_params", diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index b0441743d9..a28fc6461f 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -23,7 +23,7 @@ import json import os from dataclasses import dataclass -from collections.abc import Callable, Sequence +from collections.abc import Sequence import fsspec import 
pandas as pd @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) -from marin.execution.executor import ExecutorStep, InputName, output_path_of, this_output_path +# Note: ExecutorStep and friends are used by callers, not needed here def extract_run_name_from_path(path: str) -> str: @@ -184,46 +184,3 @@ def read_metrics_dataframe(config: EvalMetricsAnalysisConfig) -> pd.DataFrame: logger.info(f"Loaded {len(all_records)} evaluation records from {len(config.training_runs)} runs") logger.info(f"Available columns: {list(df.columns)}") return df - - -def create_analysis_step( - name: str, - training_runs: Sequence[ExecutorStep | InputName], - analysis_fn: Callable[[EvalMetricsAnalysisConfig], None], - config_class: type[EvalMetricsAnalysisConfig], - description: str | None = None, - **config_kwargs, -) -> ExecutorStep: - """ - Create an ExecutorStep for an eval metrics analysis. - - This is the factory for creating analysis steps. It: - - Converts training ExecutorSteps to blocking dependencies - - Creates the appropriate config subclass - - Returns an ExecutorStep that runs the analysis - - Args: - name: Name for this executor step - training_runs: Training run ExecutorSteps (creates blocking dependencies) - analysis_fn: The analysis function to run - config_class: The config class (EvalMetricsAnalysisConfig or subclass) - description: Optional description - **config_kwargs: Additional kwargs passed to config_class - - Returns: - ExecutorStep configured to run the analysis - """ - run_paths = [output_path_of(run) if isinstance(run, ExecutorStep) else run for run in training_runs] - - config = config_class( - training_runs=run_paths, - output_path=this_output_path(), - **config_kwargs, - ) - - return ExecutorStep( - name=name, - fn=analysis_fn, - config=config, - description=description or f"Analyze eval metrics from {len(training_runs)} training runs", - ) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index a7f9250896..1d178cf5e8 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -441,6 +441,78 @@ def compute_transformer_params( return total +# ---------------- Shared Model/Optimizer Builders ---------------- + + +def build_model_config(candidate: CandidateConfig, seq_len: int = SEQ_LEN) -> Qwen3Config: + """Build a Qwen3Config from a CandidateConfig. + + This is the shared builder used by both generate_isoflop_train_args() and + scaling_ladder's run_scaling_ladder_rung() to ensure consistent model configs. + """ + return Qwen3Config( + max_seq_len=seq_len, + hidden_dim=candidate.hidden_size, + intermediate_dim=candidate.intermediate_dim, + num_heads=candidate.num_heads, + num_kv_heads=candidate.num_kv_heads, + num_layers=candidate.num_layers, + rope=Llama3RotaryEmbeddingsConfig(), + ) + + +def build_optimizer_config(candidate: CandidateConfig) -> CautiousConfig: + """Build optimizer config from a CandidateConfig. + + This is the shared builder used by both generate_isoflop_train_args() and + scaling_ladder's run_scaling_ladder_rung() to ensure consistent optimizer configs. 
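+    Only learning_rate and beta2 come from the candidate; the remaining hyperparameters are fixed sweep defaults.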
+ """ + return CautiousConfig( + learning_rate=candidate.learning_rate, + weight_decay=0.1, + min_lr_ratio=0.0, + warmup=0.1, + beta1=0.95, + beta2=candidate.beta2, + epsilon=1e-15, + max_grad_norm=1, + adamc_weight_decay=True, + lr_schedule="linear", + decay=0.2, + ) + + +def _minima_to_candidates(minima_records: list[MinimaRecord]) -> list[CandidateConfig]: + """Convert minima records to CandidateConfig objects. + + This is used by both _run_isoflop_analysis_step() and run_isoflop_analysis() + to convert the fitted minima into usable candidate configs. + """ + configs = [] + for rec in minima_records: + if rec["hidden_dim"] == 0: + continue + configs.append( + CandidateConfig( + hidden_size=rec["hidden_dim"], + intermediate_dim=rec["hidden_dim"] * MLP_RATIO, + num_layers=rec["num_layers"], + num_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), + num_kv_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), + batch_size=rec["batch_size"], + train_steps=int(rec["optimal_tokens"] / (rec["batch_size"] * SEQ_LEN)), + learning_rate=(LR_CONSTANT * math.sqrt(rec["batch_size"])) / rec["hidden_dim"], + beta2=BETA2_BASE ** (rec["batch_size"] / BETA2_BATCH_DIVISOR), + tokens=rec["optimal_tokens"], + flops_budget=rec["flops"], + ) + ) + return configs + + +# ---------------- Training Args Generation ---------------- + + def generate_isoflop_train_args( sweep_config: IsoFlopSweepConfig, experiment_name: str, @@ -489,16 +561,8 @@ def generate_isoflop_train_args( for budget in sweep_config.budgets: for candidate in candidate_configs(sweep_config, budget, vocab_size): - # Build model config - model_cfg = Qwen3Config( - max_seq_len=sweep_config.seq_len, - hidden_dim=candidate.hidden_size, - intermediate_dim=candidate.intermediate_dim, - num_heads=candidate.num_heads, - num_kv_heads=candidate.num_kv_heads, - num_layers=candidate.num_layers, - rope=Llama3RotaryEmbeddingsConfig(), - ) + # Build model config using shared builder + model_cfg = build_model_config(candidate, sweep_config.seq_len) # Compute parameter count for TPU selection param_count = model_cfg.total_trainable_params(vocab_size) @@ -987,25 +1051,8 @@ def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: for label, (alpha, A) in scaling_fits.items(): logger.info(f" {label}: N* = {A:.2e} * C^{alpha:.3f}") - # Convert minima to CandidateConfigs - configs = [] - for rec in minima_records: - if rec["hidden_dim"] == 0: - continue - candidate = CandidateConfig( - hidden_size=rec["hidden_dim"], - intermediate_dim=rec["hidden_dim"] * MLP_RATIO, - num_layers=rec["num_layers"], - num_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), - num_kv_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), - batch_size=rec["batch_size"], - train_steps=int(rec["optimal_tokens"] / (rec["batch_size"] * SEQ_LEN)), - learning_rate=(LR_CONSTANT * math.sqrt(rec["batch_size"])) / rec["hidden_dim"], - beta2=BETA2_BASE ** (rec["batch_size"] / BETA2_BATCH_DIVISOR), - tokens=rec["optimal_tokens"], - flops_budget=rec["flops"], - ) - configs.append(candidate) + # Convert minima to CandidateConfigs using shared helper + configs = _minima_to_candidates(minima_records) result = IsoFlopAnalysisResult( configs=configs, @@ -1180,25 +1227,8 @@ def run_isoflop_analysis( # Fit scaling laws and extract optima minima_records, scaling_fits, fit_curves = fit_scaling_laws(isoflop_df) - # Convert minima records to CandidateConfig objects - configs = [] - for rec in minima_records: - if rec["hidden_dim"] == 0: - continue - candidate = CandidateConfig( - 
hidden_size=rec["hidden_dim"], - intermediate_dim=rec["hidden_dim"] * MLP_RATIO, - num_layers=rec["num_layers"], - num_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), - num_kv_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), - batch_size=rec["batch_size"], - train_steps=int(rec["optimal_tokens"] / (rec["batch_size"] * SEQ_LEN)), - learning_rate=(LR_CONSTANT * math.sqrt(rec["batch_size"])) / rec["hidden_dim"], - beta2=BETA2_BASE ** (rec["batch_size"] / BETA2_BATCH_DIVISOR), - tokens=rec["optimal_tokens"], - flops_budget=rec["flops"], - ) - configs.append(candidate) + # Convert minima records to CandidateConfig objects using shared helper + configs = _minima_to_candidates(minima_records) return IsoFlopAnalysisResult( configs=configs, diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index e1dcd92155..4760e8c414 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -49,10 +49,7 @@ from fray.cluster import ResourceConfig from levanter.checkpoint import CheckpointerConfig from levanter.data.text import LMMixtureDatasetConfig -from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig from levanter.main.train_lm import TrainLmConfig -from levanter.models.qwen import Qwen3Config -from levanter.optim.cautious import CautiousConfig from levanter.tracker.wandb import WandbConfig from levanter.trainer import TrainerConfig @@ -63,6 +60,8 @@ from marin.scaling_laws.isoflop_analysis import ( CandidateConfig, IsoFlopSweepConfig, + build_model_config, + build_optimizer_config, pick_v5p_type, predict_optimal_config, ) @@ -203,16 +202,8 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: f" learning_rate={candidate.learning_rate:.6f}, tokens={candidate.tokens:.2e}" ) - # Build model config - model_cfg = Qwen3Config( - max_seq_len=config.seq_len, - hidden_dim=candidate.hidden_size, - intermediate_dim=candidate.intermediate_dim, - num_heads=candidate.num_heads, - num_kv_heads=candidate.num_kv_heads, - num_layers=candidate.num_layers, - rope=Llama3RotaryEmbeddingsConfig(), - ) + # Build model config using shared builder + model_cfg = build_model_config(candidate, config.seq_len) # Pick TPU type based on memory requirements param_count = model_cfg.total_trainable_params(vocab_size) @@ -225,20 +216,8 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: vocab_size, ) - # Build optimizer config (matches isoflop_sweep defaults) - optimizer_cfg = CautiousConfig( - learning_rate=candidate.learning_rate, - weight_decay=0.1, - min_lr_ratio=0.0, - warmup=0.1, - beta1=0.95, - beta2=candidate.beta2, - epsilon=1e-15, - max_grad_norm=1, - adamc_weight_decay=True, - lr_schedule="linear", - decay=0.2, - ) + # Build optimizer config using shared builder + optimizer_cfg = build_optimizer_config(candidate) # Prepare data config # Accepts both string paths and LMMixtureDatasetConfig From 614827daee935f3358c7ebb5738bf98e06a115ae Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 22 Dec 2025 08:14:40 -0800 Subject: [PATCH 12/79] Refactoring --- .../marin/scaling_laws/eval_metrics_reader.py | 2 - .../marin/scaling_laws/isoflop_analysis.py | 74 +++---------------- .../src/marin/scaling_laws/scaling_ladder.py | 48 ++---------- 3 files changed, 17 insertions(+), 107 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index a28fc6461f..3b05eae7b5 100644 --- 
a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -37,8 +37,6 @@ logger = logging.getLogger(__name__) -# Note: ExecutorStep and friends are used by callers, not needed here - def extract_run_name_from_path(path: str) -> str: """Extract run name (last component) from a checkpoint path. diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 1d178cf5e8..ef6352de54 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -187,13 +187,6 @@ class IsoFlopTrainArgs: # ---------------- Typed Records ---------------- -class _NearestConfig(TypedDict): - hidden_dim: int - num_layers: int - batch_size: int - params: float - - class MinimaRecord(TypedDict): label: str flops: float @@ -673,30 +666,6 @@ def objective(params): return float(result[0]), float(result[1]), float(result[2]) -def _compute_optimal_params(flops: float, tokens: float) -> float: - """Compute optimal parameters from C = 6 * N * P approximation.""" - return flops / (6 * tokens) - - -def _find_nearest_config(df: pd.DataFrame, flops: float, tokens: float) -> _NearestConfig: - """Find the nearest actual config from the dataframe to use as template.""" - sub = df[df.flops == flops] - if sub.empty: - sub = df - idx = (sub.tokens - tokens).abs().argmin() - row = sub.iloc[idx] - - run_name = row["name"] - meta = parse_isoflop_run_name(run_name) - - return { - "hidden_dim": int(meta["d"]) if meta else 0, - "num_layers": int(meta["L"]) if meta else 0, - "batch_size": int(meta["B"]) if meta else 0, - "params": float(row.get("params", _compute_optimal_params(flops, tokens))), - } - - # ---------------- Core Analysis ---------------- @@ -742,13 +711,14 @@ def fit_scaling_laws( if a == 0: continue - # Compute minimum L_opt = -b / (2 * a) N_star = float(10**L_opt) loss_opt = float(a * L_opt**2 + b * L_opt + c) - # Find nearest actual config for template - nearest = _find_nearest_config(sub, C, N_star) + idx = (sub.tokens - N_star).abs().argmin() + nearest_row = sub.iloc[idx] + run_name = nearest_row["name"] + meta = parse_isoflop_run_name(run_name) minima_records.append( { @@ -756,10 +726,10 @@ def fit_scaling_laws( "flops": float(C), "optimal_tokens": N_star, "loss_at_optimal": loss_opt, - "hidden_dim": nearest["hidden_dim"], - "num_layers": nearest["num_layers"], - "batch_size": nearest["batch_size"], - "optimal_params": float(nearest["params"]), + "hidden_dim": int(meta["d"]) if meta else 0, + "num_layers": int(meta["L"]) if meta else 0, + "batch_size": int(meta["B"]) if meta else 0, + "optimal_params": float(nearest_row.get("params", C / (6 * N_star))), } ) @@ -820,8 +790,6 @@ def transform_metrics_for_isoflop( for _, row in final_metrics.iterrows(): run_path = row["run_path"] run_name = extract_run_name_from_path(run_path) - - # Parse metadata from run name meta = parse_isoflop_run_name(run_name) if meta is None: logger.warning(f"Could not parse metadata from run name: {run_name}") @@ -904,18 +872,15 @@ def predict_optimal_config( logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") - # Use default config if none provided if sweep_config is None: sweep_config = IsoFlopSweepConfig() - # Generate candidates for this budget candidates = list(candidate_configs(sweep_config, target_flops, vocab_size)) if not candidates: logger.warning(f"No valid candidates found for budget {target_flops:.2e}") 
return None - # Find candidate closest to optimal token count best = min(candidates, key=lambda c: abs(c.tokens - optimal_tokens)) logger.info( @@ -1023,17 +988,13 @@ class IsoFlopAnalysisConfig(EvalMetricsAnalysisConfig): def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: """Execute scaling ladder analysis (called by ExecutorStep).""" - # Read metrics from training runs raw_df = read_metrics_dataframe(config) if raw_df.empty: logger.warning("No eval metrics found in training runs") return - # Convert label_map tuple to dict if provided label_map = dict(config.label_map) if config.label_map else None - - # Transform to isoflop analysis format isoflop_df = transform_metrics_for_isoflop(raw_df, config.metric_key, label_map) if isoflop_df.empty: @@ -1044,14 +1005,12 @@ def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: logger.info(f"Labels found: {isoflop_df['label'].unique().tolist()}") logger.info(f"FLOP budgets: {sorted(isoflop_df['flops'].unique())}") - # Fit scaling laws minima_records, scaling_fits, fit_curves = fit_scaling_laws(isoflop_df) logger.info(f"Found {len(minima_records)} optimal configurations") for label, (alpha, A) in scaling_fits.items(): logger.info(f" {label}: N* = {A:.2e} * C^{alpha:.3f}") - # Convert minima to CandidateConfigs using shared helper configs = _minima_to_candidates(minima_records) result = IsoFlopAnalysisResult( @@ -1062,17 +1021,14 @@ def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: fit_curves=fit_curves, ) - # Save outputs fs, _, _ = fsspec.get_fs_token_paths(config.output_path) fs.makedirs(config.output_path, exist_ok=True) - # Save result JSON result_path = os.path.join(config.output_path, "isoflop_analysis_result.json") with fs.open(result_path, "w") as f: json.dump(result.to_json_dict(), f, indent=2) logger.info(f"Saved results to {result_path}") - # Save plots if enabled if config.save_plots: from marin.scaling_laws.scaling_plots import ( create_isoflop_plot, @@ -1084,7 +1040,6 @@ def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: fig_scaling = create_scaling_plot(minima_records, scaling_fits) save_plots(fig_isoflop, fig_scaling, config.output_path) - # Upload to WandB if enabled if config.upload_to_wandb: from marin.scaling_laws.scaling_plots import upload_plots_to_wandb @@ -1184,15 +1139,8 @@ def run_isoflop_analysis( Returns: IsoFlopAnalysisResult with configs, scaling_fits, and analysis data """ - # Convert to paths - run_paths = [] - for run in training_runs: - if isinstance(run, ExecutorStep): - run_paths.append(output_path_of(run)) - else: - run_paths.append(run) + run_paths = [output_path_of(run) if isinstance(run, ExecutorStep) else run for run in training_runs] - # Read metrics config = EvalMetricsAnalysisConfig( training_runs=run_paths, output_path="analysis/scaling_ladder", @@ -1209,7 +1157,6 @@ def run_isoflop_analysis( fit_curves={}, ) - # Transform to isoflop format isoflop_df = transform_metrics_for_isoflop(raw_df, metric_key, label_map) if isoflop_df.empty: @@ -1224,10 +1171,7 @@ def run_isoflop_analysis( logger.info(f"Transformed {len(isoflop_df)} runs for scaling ladder analysis") - # Fit scaling laws and extract optima minima_records, scaling_fits, fit_curves = fit_scaling_laws(isoflop_df) - - # Convert minima records to CandidateConfig objects using shared helper configs = _minima_to_candidates(minima_records) return IsoFlopAnalysisResult( diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py 
b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 4760e8c414..94fd21421c 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -145,26 +145,20 @@ class ScalingLadderRungConfig: """Optional sweep config for predict_optimal_config. Uses defaults if None.""" -def load_scaling_fits(analysis_path: str) -> dict[str, tuple[float, float]]: - """Load scaling fits from an IsoFLOP analysis output.""" - result_path = os.path.join(analysis_path, "isoflop_analysis_result.json") +def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: + """Run one rung of the scaling ladder (one compute-optimal training run).""" + result_path = os.path.join(config.analysis_output_path, "isoflop_analysis_result.json") fs, _, _ = fsspec.get_fs_token_paths(result_path) with fs.open(result_path, "r") as f: - result = json.load(f) + analysis_result = json.load(f) - # Convert lists back to tuples scaling_fits: dict[str, tuple[float, float]] = {} - for key, value in result["scaling_fits"].items(): + for key, value in analysis_result["scaling_fits"].items(): if len(value) != 2: raise ValueError(f"Expected 2 scaling fit values for '{key}', got {len(value)}") scaling_fits[key] = (float(value[0]), float(value[1])) - return scaling_fits - -def get_optimal_candidate(config: ScalingLadderRungConfig) -> CandidateConfig: - """Load scaling fits and predict optimal config for target budget.""" - scaling_fits = load_scaling_fits(config.analysis_output_path) vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) candidate = predict_optimal_config( @@ -177,24 +171,9 @@ def get_optimal_candidate(config: ScalingLadderRungConfig) -> CandidateConfig: if candidate is None: raise RuntimeError( - f"Could not find optimal config for budget {config.target_budget:.2e} " f"and label '{config.label}'" + f"Could not find optimal config for budget {config.target_budget:.2e} and label '{config.label}'" ) - return candidate - - -def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: - """Run one rung of the scaling ladder (one compute-optimal training run). - - This function: - 1. Loads scaling fits from the analysis output - 2. Predicts the optimal config for the target budget - 3. Trains a model with that config using the same infrastructure as default_train - """ - # Get the optimal candidate config from analysis - candidate = get_optimal_candidate(config) - vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) - logger.info( f"Training with optimal config for {config.target_budget:.2e} FLOPs:\n" f" hidden_size={candidate.hidden_size}, num_layers={candidate.num_layers}\n" @@ -202,10 +181,8 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: f" learning_rate={candidate.learning_rate:.6f}, tokens={candidate.tokens:.2e}" ) - # Build model config using shared builder model_cfg = build_model_config(candidate, config.seq_len) - # Pick TPU type based on memory requirements param_count = model_cfg.total_trainable_params(vocab_size) tpu_type = pick_v5p_type( param_count, @@ -216,15 +193,12 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: vocab_size, ) - # Build optimizer config using shared builder optimizer_cfg = build_optimizer_config(candidate) - # Prepare data config - # Accepts both string paths and LMMixtureDatasetConfig - # If validation sets are needed, they should be pre-configured in the LMMixtureDatasetConfig + # Accepts both string paths and LMMixtureDatasetConfig. 
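+    # ExecutorStep inputs are resolved via output_path_of() in scaling_ladder_rung_step.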
+ # If validation sets are needed, they should be pre-configured in the LMMixtureDatasetConfig. pretraining_data = _prepare_data_config(config.tokenized) - # Build TrainLmConfig (mirrors default_train structure) train_config = TrainLmConfig( data=pretraining_data, trainer=TrainerConfig( @@ -252,7 +226,6 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: optimizer=optimizer_cfg, ) - # Build pod config and run training full_config = TrainLmOnPodConfig( train_config=train_config, resources=ResourceConfig.with_tpu(tpu_type), @@ -292,14 +265,11 @@ def scaling_ladder_rung_step( Returns: ExecutorStep configured to run one optimal training run """ - # Resolve tokenized data - works like default_train if isinstance(tokenized, ExecutorStep): resolved_tokenized: InputName | str | LMMixtureDatasetConfig = output_path_of(tokenized) elif isinstance(tokenized, LMMixtureDatasetConfig): - # Pass through LMMixtureDatasetConfig directly resolved_tokenized = tokenized else: - # InputName or string path resolved_tokenized = tokenized output_path = override_output_path if override_output_path is not None else this_output_path() @@ -407,7 +377,6 @@ def scaling_ladder_suite( """ from marin.scaling_laws.isoflop_analysis import isoflop_analysis_step - # Create the IsoFLOP analysis step analysis = isoflop_analysis_step( name=f"{name}-analysis", training_runs=training_runs, @@ -420,7 +389,6 @@ def scaling_ladder_suite( wandb_run_name=f"{name}-analysis", ) - # Create scaling ladder rungs (optimal training steps) for each target budget optimal_runs = [] for budget in target_budgets: run_step = scaling_ladder_rung_step( From e3f961be08b0b528f87c6445ca9ed1d212f171d3 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 22 Dec 2025 08:20:22 -0800 Subject: [PATCH 13/79] Lint --- experiments/defaults.py | 2 - lib/marin/src/marin/scaling_laws/__init__.py | 25 +++---- .../src/marin/scaling_laws/scaling_ladder.py | 1 - .../src/marin/scaling_laws/scaling_plots.py | 2 +- .../migrations/migrate_isoflop_wandb_runs.py | 75 +++++++++++-------- tests/test_scaling_laws.py | 3 +- 6 files changed, 55 insertions(+), 53 deletions(-) diff --git a/experiments/defaults.py b/experiments/defaults.py index fa5b8ebb10..cfd7f78255 100644 --- a/experiments/defaults.py +++ b/experiments/defaults.py @@ -46,7 +46,6 @@ CORE_TASKS, MMLU_TASKS, convert_to_levanter_task_config, - convert_to_task_metrics, ) from experiments.llama import compute_num_parameters, llama_8b from experiments.paloma import paloma_tokenized @@ -59,7 +58,6 @@ InputName, VersionedValue, ensure_versioned, - get_executor_step, this_output_path, unwrap_versioned_value, ) diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index c1bb789643..63a03a858b 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -44,35 +44,28 @@ ) __all__ = [ - # Primary interface (ExecutorStep factories) - "isoflop_analysis_step", - "scaling_ladder_suite", - "scaling_ladder_rung_step", - # Programmatic interface - "run_isoflop_analysis", - "generate_isoflop_train_args", - # Dataclasses + "DEFAULT_BUDGETS", "CandidateConfig", "IsoFlopAnalysisConfig", "IsoFlopAnalysisResult", "IsoFlopSweepConfig", "IsoFlopTrainArgs", - "ScalingLadderSuite", "ScalingLadderRungConfig", - # Constants - "DEFAULT_BUDGETS", - # Shared builders + "ScalingLadderSuite", "build_model_config", "build_optimizer_config", - # Utilities "candidate_configs", "compute_transformer_params", + "create_isoflop_plot", + 
"create_scaling_plot", + "generate_isoflop_train_args", + "isoflop_analysis_step", "pick_v5p_type", "predict_optimal_config", "predict_optimal_configs_for_budgets", - # Plotting functions - "create_isoflop_plot", - "create_scaling_plot", + "run_isoflop_analysis", "save_plots", + "scaling_ladder_rung_step", + "scaling_ladder_suite", "upload_plots_to_wandb", ] diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 94fd21421c..8881e9522e 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -58,7 +58,6 @@ from marin.processing.tokenize.data_configs import add_validation_sets_to_mixture, lm_data_config from marin.processing.tokenize.tokenize import TokenizeConfig from marin.scaling_laws.isoflop_analysis import ( - CandidateConfig, IsoFlopSweepConfig, build_model_config, build_optimizer_config, diff --git a/lib/marin/src/marin/scaling_laws/scaling_plots.py b/lib/marin/src/marin/scaling_laws/scaling_plots.py index a3a257388d..d0240e12f3 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_plots.py +++ b/lib/marin/src/marin/scaling_laws/scaling_plots.py @@ -250,7 +250,7 @@ def create_scaling_plot( y=list(map(float, N_fit)), mode="lines", line=dict(color=color, dash=dash, width=_SCALE_LINE["width"]), - name=f"{lab} fit (α={alpha:.3f})", + name=f"{lab} fit (a={alpha:.3f})", legendgroup=lab, ) ) diff --git a/scripts/migrations/migrate_isoflop_wandb_runs.py b/scripts/migrations/migrate_isoflop_wandb_runs.py index 94df5bf55a..35cab22a0a 100644 --- a/scripts/migrations/migrate_isoflop_wandb_runs.py +++ b/scripts/migrations/migrate_isoflop_wandb_runs.py @@ -1,4 +1,18 @@ #!/usr/bin/env python3 +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Migrate WandB isoflop runs to match migrated checkpoint paths. @@ -15,7 +29,6 @@ import logging import re import sys -from typing import Optional try: import wandb @@ -40,7 +53,7 @@ def copy_wandb_run( ) -> bool: """ Copy a WandB run to a new run with a different name. 
- + Args: api: WandB API instance source_run: The source run to copy @@ -48,14 +61,14 @@ def copy_wandb_run( entity: WandB entity project: WandB project dry_run: If True, don't actually create the run - + Returns: True if successful (or would be successful in dry run) """ if dry_run: logger.info(f" [DRY RUN] Would copy {source_run.name} -> {new_name}") return True - + try: # Initialize a new run with the clean name new_run = wandb.init( @@ -67,16 +80,16 @@ def copy_wandb_run( config=dict(source_run.config), tags=list(source_run.tags), ) - + # Copy summary metrics summary = dict(source_run.summary) for key, value in summary.items(): new_run.summary[key] = value - + logger.info(f" Created new run: {new_name}") new_run.finish() return True - + except Exception as e: logger.error(f" Failed to copy run {source_run.name}: {e}") return False @@ -84,12 +97,12 @@ def copy_wandb_run( def migrate_isoflop_wandb_runs( entity_project: str, - run_name_filter: Optional[str] = None, + run_name_filter: str | None = None, dry_run: bool = True, ) -> None: """ Migrate WandB isoflop runs by copying them without hash suffixes. - + Args: entity_project: WandB entity/project (format: 'entity/project') run_name_filter: Optional filter to only process specific runs @@ -97,65 +110,65 @@ def migrate_isoflop_wandb_runs( """ if "/" not in entity_project: raise ValueError(f"Invalid entity_project format: {entity_project}. Expected 'entity/project'") - + entity, project = entity_project.split("/", 1) api = wandb.Api() - + logger.info(f"Querying WandB for isoflop runs in {entity_project}...") - + # Query for isoflop runs with hash suffixes filters = { "displayName": {"$regex": "isoflop"}, "state": "finished", } - + runs = api.runs(entity_project, filters=filters) - + migrated_count = 0 skipped_count = 0 error_count = 0 - + for run in runs: display_name = run.displayName - + # Check if this run has a hash suffix if not re.search(r"-[0-9a-fA-F]{6}$", display_name): logger.debug(f"Skipping {display_name} (no hash suffix)") skipped_count += 1 continue - + # Strip the hash to get the clean name clean_name = re.sub(r"-[0-9a-fA-F]{6}$", "", display_name) - + # Apply filter if specified if run_name_filter and run_name_filter not in clean_name: logger.debug(f"Skipping {display_name} (doesn't match filter)") skipped_count += 1 continue - + # Check if a run with the clean name already exists try: - existing = api.run(f"{entity_project}/{clean_name}") + api.run(f"{entity_project}/{clean_name}") logger.info(f"Skipping {display_name} -> {clean_name} (already exists)") skipped_count += 1 continue except wandb.errors.CommError: # Run doesn't exist, we can create it pass - + logger.info(f"Processing: {display_name} -> {clean_name}") - + if copy_wandb_run(api, run, clean_name, entity, project, dry_run): migrated_count += 1 else: error_count += 1 - + logger.info("\n" + "=" * 60) logger.info("Migration Summary:") logger.info(f" Migrated: {migrated_count}") logger.info(f" Skipped: {skipped_count}") logger.info(f" Errors: {error_count}") - + if dry_run: logger.info("\nDry run complete. 
Run with --execute to perform the migration.") @@ -176,34 +189,34 @@ def main(): python migrate_isoflop_wandb_runs.py marin-community/marin --filter nemo --execute """, ) - + parser.add_argument( "entity_project", help="WandB entity/project (format: entity/project)", ) - + parser.add_argument( "--execute", action="store_true", help="Actually perform the migration (default is dry run)", ) - + parser.add_argument( "--filter", help="Only process runs whose clean name contains this string", ) - + parser.add_argument( "--verbose", action="store_true", help="Enable verbose logging", ) - + args = parser.parse_args() - + if args.verbose: logging.getLogger().setLevel(logging.DEBUG) - + try: migrate_isoflop_wandb_runs( entity_project=args.entity_project, diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 373ad56815..5f6380f84a 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -39,7 +39,6 @@ round_to_power_of_two, ) - # --- round_to_power_of_two tests --- @@ -249,7 +248,7 @@ def test_robust_quad_logx_handles_noise(): noise = jnp.array([0.01, -0.02, 0.015, -0.01, 0.005]) y = y_clean + noise - a, b, c = robust_quad_logx(x, y) + a, b, _ = robust_quad_logx(x, y) assert abs(a - 0.05) < 0.05 assert abs(b - (-1.5)) < 0.5 From 32fb07418bad96d08014224b40ec80c3e882dde8 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 22 Dec 2025 08:24:33 -0800 Subject: [PATCH 14/79] Tweak --- lib/marin/src/marin/scaling_laws/eval_metrics_reader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index 3b05eae7b5..df48b694c4 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -15,8 +15,10 @@ """Base infrastructure for eval metrics analysis. This module provides a base config and utilities for analysis jobs that -read eval_metrics.jsonl files from completed training runs. Specific -analyses (like IsoFlop) should subclass EvalMetricsAnalysisConfig. +read eval_metrics.jsonl files from completed training runs. The subclassing +pattern mirrors the Evaluator approach in +lib/marin/src/marin/evaluation/evaluators/evaluator.py, so specific analyses +(like IsoFlop) should subclass EvalMetricsAnalysisConfig. 
""" import logging From 6371fa4f2fffa67c44bc1f6389591984eb0f5626 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 22 Dec 2025 08:35:45 -0800 Subject: [PATCH 15/79] Fix Caller --- experiments/exp1603_subgroup_evals.py | 31 +++++++++++++++++++-------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/experiments/exp1603_subgroup_evals.py b/experiments/exp1603_subgroup_evals.py index 26066211f8..872d397fe0 100644 --- a/experiments/exp1603_subgroup_evals.py +++ b/experiments/exp1603_subgroup_evals.py @@ -24,6 +24,7 @@ from experiments.models import ModelConfig, download_model_step from marin.execution.executor import executor_main, output_path_of, versioned from marin.evaluation.log_probs import default_lm_log_probs +from marin.scaling_laws.isoflop_analysis import build_model_config # This is painfully slow to run in dry run mode # nodryrun @@ -40,8 +41,12 @@ def create_eval_steps() -> list: steps = [] dist_eval = distributional_eval_sets(llama3_tokenizer) - for model, metadata in list(zip(*MARIN_SCALING_SUITES["nemotron"], strict=False)): - name = f"marin-nemo-{metadata[0]}C-{metadata[-3] * metadata[-2] * 4096}T-{metadata[1]}W-{metadata[2]}D" + for model, candidate in list(zip(*MARIN_SCALING_SUITES["nemotron"], strict=False)): + total_tokens = candidate.batch_size * candidate.train_steps * 4096 + name = ( + f"marin-nemo-{candidate.flops_budget:.0e}C-{total_tokens}T-" + f"{candidate.hidden_size}W-{candidate.num_layers}D" + ) step = evaluate_levanter_lm_evaluation_harness( model_name=name, @@ -53,7 +58,7 @@ def create_eval_steps() -> list: logprobs_step = default_lm_log_probs( output_path_of(model).cd("checkpoints"), - metadata[-1], + build_model_config(candidate), dist_eval, resource_config=ResourceConfig.with_tpu("v5p-8"), checkpoint_is_hf=False, @@ -62,8 +67,12 @@ def create_eval_steps() -> list: steps.append(logprobs_step) - for model, metadata in list(zip(*MARIN_SCALING_SUITES["common_pile"], strict=False)): - name = f"marin-comma-{metadata[0]}C-{metadata[-3] * metadata[-2] * 4096}T-{metadata[1]}W-{metadata[2]}D" + for model, candidate in list(zip(*MARIN_SCALING_SUITES["common_pile"], strict=False)): + total_tokens = candidate.batch_size * candidate.train_steps * 4096 + name = ( + f"marin-comma-{candidate.flops_budget:.0e}C-{total_tokens}T-" + f"{candidate.hidden_size}W-{candidate.num_layers}D" + ) step = evaluate_levanter_lm_evaluation_harness( model_name=name, @@ -75,7 +84,7 @@ def create_eval_steps() -> list: logprobs_step = default_lm_log_probs( output_path_of(model).cd("checkpoints"), - metadata[-1], + build_model_config(candidate), dist_eval, resource_config=ResourceConfig.with_tpu("v5p-8"), checkpoint_is_hf=False, @@ -84,8 +93,12 @@ def create_eval_steps() -> list: steps.append(logprobs_step) - for model, metadata in list(zip(*MARIN_SCALING_SUITES["dclm-default"], strict=False)): - name = f"marin-dclm-{metadata[0]}C-{metadata[-3] * metadata[-2] * 4096}T-{metadata[1]}W-{metadata[2]}D" + for model, candidate in list(zip(*MARIN_SCALING_SUITES["dclm-default"], strict=False)): + total_tokens = candidate.batch_size * candidate.train_steps * 4096 + name = ( + f"marin-dclm-{candidate.flops_budget:.0e}C-{total_tokens}T-" + f"{candidate.hidden_size}W-{candidate.num_layers}D" + ) step = evaluate_levanter_lm_evaluation_harness( model_name=name, @@ -97,7 +110,7 @@ def create_eval_steps() -> list: logprobs_step = default_lm_log_probs( output_path_of(model).cd("checkpoints"), - metadata[-1], + build_model_config(candidate), dist_eval, resource_config=ResourceConfig.with_tpu("v5p-8"), 
checkpoint_is_hf=False, From 06179a3a4d5f1771b5a34a906e254cab9974b319 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 22 Dec 2025 08:40:51 -0800 Subject: [PATCH 16/79] I don't actually want this, migrating is simpler --- lib/levanter/src/levanter/eval.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lib/levanter/src/levanter/eval.py b/lib/levanter/src/levanter/eval.py index 4b9e08b618..3eacb84d5b 100644 --- a/lib/levanter/src/levanter/eval.py +++ b/lib/levanter/src/levanter/eval.py @@ -234,10 +234,6 @@ def eval_callback(step: StepInfo): fs.makedirs(checkpoint_path, exist_ok=True) with fs.open(metrics_file, "a") as f: record = {"step": int(step_count), **metrics_to_write} - # Include WandB run info for backfill/lookup - wandb_info = levanter.tracker.current_tracker_info() - if wandb_info: - record["_tracker"] = wandb_info f.write(json.dumps(record) + "\n") return From 81ed7a01e8669cb52d7f2d04ef1a0c53bffa2c74 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 22 Dec 2025 12:10:25 -0800 Subject: [PATCH 17/79] First Refactor --- experiments/isoflop_sweep.py | 29 +++++--- lib/marin/pyproject.toml | 1 + lib/marin/src/marin/scaling_laws/__init__.py | 2 + .../marin/scaling_laws/isoflop_analysis.py | 70 ++++++++++--------- .../src/marin/scaling_laws/scaling_ladder.py | 8 +-- .../src/marin/scaling_laws/scaling_plots.py | 22 +++--- tests/test_scaling_laws.py | 40 ++++++++--- uv.lock | 2 + 8 files changed, 105 insertions(+), 69 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index bc825dfa15..c198c8d758 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -45,16 +45,19 @@ @dataclass(frozen=True) -class IsoFlopExperimentConfig(IsoFlopSweepConfig): - """Extended config for isoflop experiments with dataset and eval settings. +class IsoFlopExperimentConfig: + """Configuration for isoflop experiments with dataset and eval settings. - Inherits core sweep parameters from IsoFlopSweepConfig and adds + Composes an IsoFlopSweepConfig for core sweep parameters and adds experiment-specific settings like tokenized dataset and eval tasks. """ - tokenized_dataset: InputName | str = "" + tokenized_dataset: InputName | str """Tokenized dataset to train on.""" + sweep_config: IsoFlopSweepConfig = dataclasses.field(default_factory=IsoFlopSweepConfig) + """Core sweep parameters (budgets, seq_len, etc.).""" + eval_tasks: tuple[EvalTaskConfig, ...] | None = None """Evaluation tasks to run after training (disabled by default).""" @@ -104,11 +107,11 @@ def generate_isoflop_steps( - candidates: CandidateConfig for each training run (contains budget, hidden_size, num_layers, batch_size, train_steps, learning_rate, etc.) """ - vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) + vocab_size = get_vocab_size_for_tokenizer(config.sweep_config.tokenizer) # Get training arguments from the library train_args_list = generate_isoflop_train_args( - sweep_config=config, + sweep_config=config.sweep_config, experiment_name=experiment_name, vocab_size=vocab_size, base_optimizer_config=config.base_optimizer_config, @@ -160,22 +163,28 @@ def generate_isoflop_steps( def generate_isoflop_sweep( tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, experiment_name: str, - **kwargs, + sweep_config: IsoFlopSweepConfig | None = None, + eval_tasks: tuple[EvalTaskConfig, ...] | None = None, ) -> tuple[list[ExecutorStep], list[CandidateConfig]]: """Generate an ISOFlop sweep for a tokenized dataset. Args: tokenized: Tokenized dataset to train on. 
experiment_name: Name suffix for the experiment (e.g., 'nemo', 'dclm'). - **kwargs: Additional arguments passed to IsoFlopExperimentConfig. + sweep_config: Optional custom sweep config. Uses defaults if None. + eval_tasks: Optional evaluation tasks to run after training. Returns: A tuple of: - steps: Training and evaluation ExecutorSteps for the sweep. - candidates: CandidateConfig for each training run with full config details. """ - sweep_cfg = IsoFlopExperimentConfig(tokenized_dataset=tokenized, **kwargs) - steps, candidates = generate_isoflop_steps(sweep_cfg, experiment_name) + config = IsoFlopExperimentConfig( + tokenized_dataset=tokenized, + sweep_config=sweep_config or IsoFlopSweepConfig(), + eval_tasks=eval_tasks, + ) + steps, candidates = generate_isoflop_steps(config, experiment_name) return steps, candidates diff --git a/lib/marin/pyproject.toml b/lib/marin/pyproject.toml index 0548f5e392..03be869635 100644 --- a/lib/marin/pyproject.toml +++ b/lib/marin/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "numpy", "openai", "pandas", + "plotly", "pyarrow>=22", "ray==2.53.0", "regex", diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 63a03a858b..83ae5adb21 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -19,6 +19,7 @@ IsoFlopAnalysisResult, IsoFlopSweepConfig, IsoFlopTrainArgs, + MinimaRecord, build_model_config, build_optimizer_config, candidate_configs, @@ -50,6 +51,7 @@ "IsoFlopAnalysisResult", "IsoFlopSweepConfig", "IsoFlopTrainArgs", + "MinimaRecord", "ScalingLadderRungConfig", "ScalingLadderSuite", "build_model_config", diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index ef6352de54..5304f21a2d 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -37,8 +37,7 @@ import os import re from collections.abc import Iterator, Sequence -from dataclasses import asdict, dataclass, replace -from typing import NotRequired, TypedDict +from dataclasses import asdict, dataclass, field, replace import fsspec import jax.numpy as jnp @@ -187,7 +186,10 @@ class IsoFlopTrainArgs: # ---------------- Typed Records ---------------- -class MinimaRecord(TypedDict): +@dataclass +class MinimaRecord: + """Record of optimal configuration found at a specific (label, flops) point.""" + label: str flops: float optimal_tokens: float @@ -196,8 +198,8 @@ class MinimaRecord(TypedDict): num_layers: int batch_size: int optimal_params: float - scaling_alpha: NotRequired[float] - scaling_A: NotRequired[float] + scaling_alpha: float | None = None + scaling_A: float | None = None # ---------------- Candidate Config Generation ---------------- @@ -483,21 +485,21 @@ def _minima_to_candidates(minima_records: list[MinimaRecord]) -> list[CandidateC """ configs = [] for rec in minima_records: - if rec["hidden_dim"] == 0: + if rec.hidden_dim == 0: continue configs.append( CandidateConfig( - hidden_size=rec["hidden_dim"], - intermediate_dim=rec["hidden_dim"] * MLP_RATIO, - num_layers=rec["num_layers"], - num_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), - num_kv_heads=max(1, rec["hidden_dim"] // HIDDEN_HEAD_RATIO), - batch_size=rec["batch_size"], - train_steps=int(rec["optimal_tokens"] / (rec["batch_size"] * SEQ_LEN)), - learning_rate=(LR_CONSTANT * math.sqrt(rec["batch_size"])) / rec["hidden_dim"], - beta2=BETA2_BASE ** (rec["batch_size"] / 
BETA2_BATCH_DIVISOR), - tokens=rec["optimal_tokens"], - flops_budget=rec["flops"], + hidden_size=rec.hidden_dim, + intermediate_dim=rec.hidden_dim * MLP_RATIO, + num_layers=rec.num_layers, + num_heads=max(1, rec.hidden_dim // HIDDEN_HEAD_RATIO), + num_kv_heads=max(1, rec.hidden_dim // HIDDEN_HEAD_RATIO), + batch_size=rec.batch_size, + train_steps=int(rec.optimal_tokens / (rec.batch_size * SEQ_LEN)), + learning_rate=(LR_CONSTANT * math.sqrt(rec.batch_size)) / rec.hidden_dim, + beta2=BETA2_BASE ** (rec.batch_size / BETA2_BATCH_DIVISOR), + tokens=rec.optimal_tokens, + flops_budget=rec.flops, ) ) return configs @@ -721,32 +723,32 @@ def fit_scaling_laws( meta = parse_isoflop_run_name(run_name) minima_records.append( - { - "label": lab, - "flops": float(C), - "optimal_tokens": N_star, - "loss_at_optimal": loss_opt, - "hidden_dim": int(meta["d"]) if meta else 0, - "num_layers": int(meta["L"]) if meta else 0, - "batch_size": int(meta["B"]) if meta else 0, - "optimal_params": float(nearest_row.get("params", C / (6 * N_star))), - } + MinimaRecord( + label=lab, + flops=float(C), + optimal_tokens=N_star, + loss_at_optimal=loss_opt, + hidden_dim=int(meta["d"]) if meta else 0, + num_layers=int(meta["L"]) if meta else 0, + batch_size=int(meta["B"]) if meta else 0, + optimal_params=float(nearest_row.get("params", C / (6 * N_star))), + ) ) # Fit scaling law N* ~ A * C^alpha per dataset scaling_fits: dict[str, tuple[float, float]] = {} by_lab: dict[str, list[MinimaRecord]] = {} for rec in minima_records: - by_lab.setdefault(rec["label"], []).append(rec) + by_lab.setdefault(rec.label, []).append(rec) for lab in datasets: recs = by_lab.get(lab, []) if len(recs) < 2: continue - recs = sorted(recs, key=lambda r: r["flops"]) - Cs = jnp.array([r["flops"] for r in recs]) - Ns = jnp.array([r["optimal_tokens"] for r in recs]) + recs = sorted(recs, key=lambda r: r.flops) + Cs = jnp.array([r.flops for r in recs]) + Ns = jnp.array([r.optimal_tokens for r in recs]) alpha, logA = jnp.polyfit(jnp.log10(Cs), jnp.log10(Ns), 1) A = float(10**logA) @@ -755,8 +757,8 @@ def fit_scaling_laws( # Augment minima records with scaling fit params for rec in recs: - rec["scaling_alpha"] = alpha - rec["scaling_A"] = A + rec.scaling_alpha = alpha + rec.scaling_A = A return minima_records, scaling_fits, fit_curves @@ -953,7 +955,7 @@ def to_json_dict(self) -> dict: return { "configs": [asdict(c) for c in self.configs], "scaling_fits": {k: list(v) for k, v in self.scaling_fits.items()}, - "minima_records": self.minima_records, + "minima_records": [asdict(r) for r in self.minima_records], } diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 8881e9522e..f03eda2d9f 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -18,8 +18,7 @@ configurations derived from IsoFLOP analysis. Usage: - from marin.scaling_laws import isoflop_analysis_step - from marin.scaling_laws.scaling_ladder import scaling_ladder_rung_step + from marin.scaling_laws import isoflop_analysis_step, scaling_ladder_rung_step # First, run IsoFLOP analysis analysis = isoflop_analysis_step( @@ -33,7 +32,7 @@ analysis_step=analysis, target_budget=1e21, label="nemo", - dataset=my_tokenized_dataset, + tokenized=my_tokenized_dataset, ) """ @@ -61,6 +60,7 @@ IsoFlopSweepConfig, build_model_config, build_optimizer_config, + isoflop_analysis_step, pick_v5p_type, predict_optimal_config, ) @@ -374,8 +374,6 @@ def scaling_ladder_suite( ... 
) >>> all_steps = [*isoflop_training_steps, *suite.all_steps] """ - from marin.scaling_laws.isoflop_analysis import isoflop_analysis_step - analysis = isoflop_analysis_step( name=f"{name}-analysis", training_runs=training_runs, diff --git a/lib/marin/src/marin/scaling_laws/scaling_plots.py b/lib/marin/src/marin/scaling_laws/scaling_plots.py index d0240e12f3..ec8512f0d5 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_plots.py +++ b/lib/marin/src/marin/scaling_laws/scaling_plots.py @@ -86,7 +86,7 @@ def create_isoflop_plot( df: pd.DataFrame, - minima_records: list[dict], + minima_records: list, fit_curves: dict[tuple[str, float], tuple[float, float, float]], ) -> go.Figure: """Create the IsoFLOP plot showing loss vs tokens for each compute budget. @@ -111,7 +111,7 @@ def create_isoflop_plot( fig = go.Figure() # Build lookup for minima - minima_lookup = {(rec["label"], rec["flops"]): rec for rec in minima_records} + minima_lookup = {(rec.label, rec.flops): rec for rec in minima_records} for lab in datasets: for C in buckets: @@ -159,8 +159,8 @@ def create_isoflop_plot( rec = minima_lookup[key] fig.add_trace( go.Scatter( - x=[rec["optimal_tokens"]], - y=[rec["loss_at_optimal"]], + x=[rec.optimal_tokens], + y=[rec.loss_at_optimal], mode="markers", marker=_MIN_MARKER, showlegend=False, @@ -171,7 +171,7 @@ def create_isoflop_plot( "loss=%{y:.4f}
params=%{customdata:.3e}" ), text=[C], - customdata=[rec["optimal_params"]], + customdata=[rec.optimal_params], ) ) @@ -189,13 +189,13 @@ def create_isoflop_plot( def create_scaling_plot( - minima_records: list[dict], + minima_records: list, scaling_fits: dict[str, tuple[float, float]], ) -> go.Figure: """Create the scaling law fit plot showing N* vs compute budget. Args: - minima_records: List of dicts with optimal config info per (label, flops) + minima_records: List of MinimaRecord with optimal config info per (label, flops) scaling_fits: Dict of {label: (alpha, A)} for N* ~ A * C^alpha Returns: @@ -207,7 +207,7 @@ def create_scaling_plot( # Group by label by_lab = {} for rec in minima_records: - by_lab.setdefault(rec["label"], []).append(rec) + by_lab.setdefault(rec.label, []).append(rec) datasets = list(by_lab.keys()) @@ -218,9 +218,9 @@ def create_scaling_plot( if not recs: continue - recs = sorted(recs, key=lambda r: r["flops"]) - Cs = jnp.array([r["flops"] for r in recs]) - Ns = jnp.array([r["optimal_tokens"] for r in recs]) + recs = sorted(recs, key=lambda r: r.flops) + Cs = jnp.array([r.flops for r in recs]) + Ns = jnp.array([r.optimal_tokens for r in recs]) color = PALETTE[i % len(PALETTE)] dash = DASHES[i % len(DASHES)] diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 5f6380f84a..270412eb6a 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -29,6 +29,7 @@ CandidateConfig, IsoFlopSweepConfig, IsoFlopTrainArgs, + MinimaRecord, candidate_configs, compute_total_flops, compute_transformer_params, @@ -463,13 +464,16 @@ def test_create_isoflop_plot_with_data(): } ) minima_records = [ - { - "label": "nemo", - "flops": 1e18, - "optimal_tokens": 2e9, - "loss_at_optimal": 2.3, - "optimal_params": 1e8, - } + MinimaRecord( + label="nemo", + flops=1e18, + optimal_tokens=2e9, + loss_at_optimal=2.3, + hidden_dim=512, + num_layers=8, + batch_size=64, + optimal_params=1e8, + ) ] fit_curves = {("nemo", 1e18): (0.1, -1.0, 3.0)} fig = create_isoflop_plot(df, minima_records, fit_curves) @@ -489,8 +493,26 @@ def test_create_scaling_plot_with_data(): from marin.scaling_laws import create_scaling_plot minima_records = [ - {"label": "nemo", "flops": 1e18, "optimal_tokens": 1e9}, - {"label": "nemo", "flops": 1e19, "optimal_tokens": 5e9}, + MinimaRecord( + label="nemo", + flops=1e18, + optimal_tokens=1e9, + loss_at_optimal=2.3, + hidden_dim=512, + num_layers=8, + batch_size=64, + optimal_params=1e8, + ), + MinimaRecord( + label="nemo", + flops=1e19, + optimal_tokens=5e9, + loss_at_optimal=2.1, + hidden_dim=1024, + num_layers=16, + batch_size=128, + optimal_params=5e8, + ), ] scaling_fits = {"nemo": (0.5, 1e5)} fig = create_scaling_plot(minima_records, scaling_fits) diff --git a/uv.lock b/uv.lock index 236561e869..9520a6b681 100644 --- a/uv.lock +++ b/uv.lock @@ -4342,6 +4342,7 @@ dependencies = [ { name = "numpy" }, { name = "openai" }, { name = "pandas" }, + { name = "plotly" }, { name = "pyarrow" }, { name = "ray" }, { name = "regex" }, @@ -4509,6 +4510,7 @@ requires-dist = [ { name = "numpy" }, { name = "openai" }, { name = "pandas" }, + { name = "plotly" }, { name = "prime", marker = "extra == 'rl'" }, { name = "pyarrow", specifier = ">=22" }, { name = "pylatexenc", marker = "extra == 'math'" }, From d06e7f22ae90d6cddbdd41d742c0327c0480d7ef Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 22 Dec 2025 12:18:39 -0800 Subject: [PATCH 18/79] Fixes --- lib/levanter/src/levanter/eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/lib/levanter/src/levanter/eval.py b/lib/levanter/src/levanter/eval.py index 3eacb84d5b..7407c1f631 100644 --- a/lib/levanter/src/levanter/eval.py +++ b/lib/levanter/src/levanter/eval.py @@ -234,7 +234,7 @@ def eval_callback(step: StepInfo): fs.makedirs(checkpoint_path, exist_ok=True) with fs.open(metrics_file, "a") as f: record = {"step": int(step_count), **metrics_to_write} - f.write(json.dumps(record) + "\n") + f.write(json.dumps(record, sort_keys=True) + "\n") return From aa62db9d3b220b90586b151b62e7209f00cd78d2 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 22 Dec 2025 12:33:53 -0800 Subject: [PATCH 19/79] Lint --- lib/marin/src/marin/scaling_laws/isoflop_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 5304f21a2d..8c8d21214c 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -37,7 +37,7 @@ import os import re from collections.abc import Iterator, Sequence -from dataclasses import asdict, dataclass, field, replace +from dataclasses import asdict, dataclass, replace import fsspec import jax.numpy as jnp From 681e73c2c0f4184d6b8c207c5f4d76434d54ed44 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 22 Dec 2025 12:39:30 -0800 Subject: [PATCH 20/79] Full Run Scales --- experiments/exp2166_scaling_ladder_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 962ad8981a..4df07a1bae 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -42,7 +42,7 @@ # Target budgets for compute-optimal training runs (beyond the isoflop sweep) # Set to empty list to only run analysis without training -TARGET_BUDGETS: list[float] = [] +TARGET_BUDGETS: list[float] = [1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20] nemotron_suite = scaling_ladder_suite( From 3039dd824f837faa3b068fce84707885bc069c4e Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 22 Dec 2025 16:05:50 -0800 Subject: [PATCH 21/79] Snapshot Test --- tests/test_scaling_laws.py | 113 +++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 270412eb6a..9de6930af2 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -517,3 +517,116 @@ def test_create_scaling_plot_with_data(): scaling_fits = {"nemo": (0.5, 1e5)} fig = create_scaling_plot(minima_records, scaling_fits) assert fig is not None + + +# --- Snapshot tests --- + +# Snapshot of expected output for generate_isoflop_train_args with budget=1e18. +# This captures the configuration generation logic from experiments/isoflop_sweep.py +# on the main branch to ensure the refactored code produces identical configs. 
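+# Each entry mirrors one CandidateConfig plus the derived tpu_type and run_name
+# from IsoFlopTrainArgs. As a rough sanity check (a sketch, assuming the
+# 4096-token sequence length used throughout these sweeps), the token budget of
+# an entry can be recovered directly from its fields:
+#
+#     cfg = EXPECTED_ISOFLOP_CONFIGS_1E18[0]
+#     tokens = cfg["train_steps"] * cfg["batch_size"] * 4096  # 32844 * 32 * 4096 ~= 4.3e9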
+EXPECTED_ISOFLOP_CONFIGS_1E18 = [ + { + "hidden_size": 512, + "intermediate_dim": 2048, + "num_layers": 6, + "num_heads": 4, + "num_kv_heads": 4, + "batch_size": 32, + "train_steps": 32844, + "learning_rate": 0.003646, + "beta2": 0.994962, + "tpu_type": "v5p-8", + "run_name": "isoflop-1e+18-d512-L6-B32-test-snapshot", + }, + { + "hidden_size": 640, + "intermediate_dim": 2560, + "num_layers": 7, + "num_heads": 5, + "num_kv_heads": 5, + "batch_size": 16, + "train_steps": 46274, + "learning_rate": 0.002063, + "beta2": 0.997478, + "tpu_type": "v5p-8", + "run_name": "isoflop-1e+18-d640-L7-B16-test-snapshot", + }, + { + "hidden_size": 768, + "intermediate_dim": 3072, + "num_layers": 8, + "num_heads": 6, + "num_kv_heads": 6, + "batch_size": 16, + "train_steps": 33965, + "learning_rate": 0.001719, + "beta2": 0.997478, + "tpu_type": "v5p-8", + "run_name": "isoflop-1e+18-d768-L8-B16-test-snapshot", + }, + { + "hidden_size": 896, + "intermediate_dim": 3584, + "num_layers": 10, + "num_heads": 7, + "num_kv_heads": 7, + "batch_size": 8, + "train_steps": 48105, + "learning_rate": 0.001042, + "beta2": 0.998738, + "tpu_type": "v5p-8", + "run_name": "isoflop-1e+18-d896-L10-B8-test-snapshot", + }, + { + "hidden_size": 1024, + "intermediate_dim": 4096, + "num_layers": 11, + "num_heads": 8, + "num_kv_heads": 8, + "batch_size": 8, + "train_steps": 37335, + "learning_rate": 0.000912, + "beta2": 0.998738, + "tpu_type": "v5p-8", + "run_name": "isoflop-1e+18-d1024-L11-B8-test-snapshot", + }, +] + + +def test_generate_isoflop_train_args_snapshot(): + """Snapshot test: verify generate_isoflop_train_args produces expected configs. + + This test ensures the refactored scaling_laws module produces identical + configurations to the original experiments/isoflop_sweep.py implementation. 
+ """ + config = IsoFlopSweepConfig(budgets=(1e18,)) + result = generate_isoflop_train_args( + sweep_config=config, + experiment_name="test-snapshot", + vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, + ) + + assert len(result) == len(EXPECTED_ISOFLOP_CONFIGS_1E18), ( + f"Expected {len(EXPECTED_ISOFLOP_CONFIGS_1E18)} configs, got {len(result)}" + ) + + for i, (args, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_1E18)): + c = args.candidate + actual = { + "hidden_size": c.hidden_size, + "intermediate_dim": c.intermediate_dim, + "num_layers": c.num_layers, + "num_heads": c.num_heads, + "num_kv_heads": c.num_kv_heads, + "batch_size": c.batch_size, + "train_steps": c.train_steps, + "learning_rate": round(c.learning_rate, 6), + "beta2": round(c.beta2, 6), + "tpu_type": args.tpu_type, + "run_name": args.run_name, + } + + for key in expected: + assert actual[key] == expected[key], ( + f"Config {i}: {key} mismatch: expected {expected[key]}, got {actual[key]}" + ) From b141510df74e7235a3285c9c69dabdd351a1411e Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 24 Dec 2025 07:50:16 -0800 Subject: [PATCH 22/79] Fix UV Sync --- uv.lock | 1 + 1 file changed, 1 insertion(+) diff --git a/uv.lock b/uv.lock index 9520a6b681..36130bc827 100644 --- a/uv.lock +++ b/uv.lock @@ -4500,6 +4500,7 @@ requires-dist = [ { name = "jax", marker = "extra == 'cpu'", specifier = "==0.8.0" }, { name = "jax", extras = ["cuda12"], marker = "extra == 'gpu'", specifier = "==0.8.0" }, { name = "jax", extras = ["tpu"], marker = "extra == 'tpu'", specifier = "==0.8.0" }, + { name = "jaxopt", specifier = ">=0.8.3" }, { name = "levanter", extras = ["serve"], editable = "lib/levanter" }, { name = "lm-eval", git = "https://github.com/stanford-crfm/lm-evaluation-harness?rev=d5e3391f22cde186c827674d5c3ec7c5f4fe0cab" }, { name = "lm-eval", extras = ["math"], marker = "extra == 'eval'", git = "https://github.com/stanford-crfm/lm-evaluation-harness?rev=d5e3391f22cde186c827674d5c3ec7c5f4fe0cab" }, From 8acbdb3782e17bd56cd67413e60269ff4fcce3dc Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 24 Dec 2025 08:17:59 -0800 Subject: [PATCH 23/79] Lint --- tests/test_scaling_laws.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 9de6930af2..b10f30053c 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -606,11 +606,11 @@ def test_generate_isoflop_train_args_snapshot(): vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, ) - assert len(result) == len(EXPECTED_ISOFLOP_CONFIGS_1E18), ( - f"Expected {len(EXPECTED_ISOFLOP_CONFIGS_1E18)} configs, got {len(result)}" - ) + assert len(result) == len( + EXPECTED_ISOFLOP_CONFIGS_1E18 + ), f"Expected {len(EXPECTED_ISOFLOP_CONFIGS_1E18)} configs, got {len(result)}" - for i, (args, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_1E18)): + for i, (args, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_1E18, strict=True)): c = args.candidate actual = { "hidden_size": c.hidden_size, @@ -627,6 +627,6 @@ def test_generate_isoflop_train_args_snapshot(): } for key in expected: - assert actual[key] == expected[key], ( - f"Config {i}: {key} mismatch: expected {expected[key]}, got {actual[key]}" - ) + assert ( + actual[key] == expected[key] + ), f"Config {i}: {key} mismatch: expected {expected[key]}, got {actual[key]}" From 7ba8073200d61410913c8981c9d062ff674b55a6 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Fri, 26 Dec 2025 08:24:38 -0800 Subject: [PATCH 24/79] Just one --- 
experiments/exp2166_scaling_ladder_analysis.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 4df07a1bae..54cc75d7b7 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -34,8 +34,6 @@ # Get training steps and datasets for each suite nemotron_training, _ = MARIN_SCALING_SUITES["nemotron"] -dolma3_training, _ = MARIN_SCALING_SUITES["dolma3_mix_150b"] - # --- Scaling Ladder Suites --- # These analyze completed isoflop training runs and optionally train compute-optimal models @@ -54,18 +52,7 @@ wandb_project="marin-analysis", ) - -dolma3_suite = scaling_ladder_suite( - name="exp2166-scaling-ladder-dolma3", - training_runs=dolma3_training, - target_budgets=TARGET_BUDGETS, - label="dolma3-mix-150b-1025", - tokenized=dolma3_mix, - wandb_project="marin-analysis", -) - - -all_steps = [*nemotron_suite.all_steps, *dolma3_suite.all_steps] +all_steps = [*nemotron_suite.all_steps] if __name__ == "__main__": executor_main(steps=all_steps) From 05b7e96fd02a61c8be929e046f88872baa3f29fe Mon Sep 17 00:00:00 2001 From: Helw150 Date: Fri, 26 Dec 2025 09:42:22 -0800 Subject: [PATCH 25/79] More Complete Writing --- lib/levanter/src/levanter/tracker/wandb.py | 28 +++++++++++++++-- .../marin/scaling_laws/eval_metrics_reader.py | 31 +++++++++---------- .../marin/scaling_laws/isoflop_analysis.py | 2 +- 3 files changed, 41 insertions(+), 20 deletions(-) diff --git a/lib/levanter/src/levanter/tracker/wandb.py b/lib/levanter/src/levanter/tracker/wandb.py index 2063f035ed..331d36e6f9 100644 --- a/lib/levanter/src/levanter/tracker/wandb.py +++ b/lib/levanter/src/levanter/tracker/wandb.py @@ -36,7 +36,7 @@ class WandbTracker(Tracker): name: str = "wandb" run: WandbRun - def __init__(self, run: Optional[WandbRun]): + def __init__(self, run: Optional[WandbRun], replicate_path: Optional[str] = None): import wandb if run is None: @@ -52,6 +52,7 @@ def __init__(self, run: Optional[WandbRun]): self.run = run self._last_warning_step = -500 + self._replicate_path = replicate_path def log_hyperparameters(self, hparams: dict[str, Any]): self.run.config.update(_convert_value_to_loggable_rec(hparams), allow_val_change=True) @@ -100,8 +101,28 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio def finish(self): logger.info("Finishing wandb run...") + self._write_replicate_file() self.run.finish() + def _write_replicate_file(self): + if self._replicate_path is None: + return + + import json + + import fsspec + + metrics_file = f"{self._replicate_path}/tracker_metrics.jsonl" + fs, _, _ = fsspec.get_fs_token_paths(metrics_file) + fs.makedirs(self._replicate_path, exist_ok=True) + + with fs.open(metrics_file, "w") as f: + record = { + "config": _convert_value_to_loggable_rec(dict(self.run.config)), + "summary": _convert_value_to_loggable_rec(dict(self.run.summary)), + } + f.write(json.dumps(record, sort_keys=True, default=str) + "\n") + def _convert_value_to_loggable_rec(value: Any): if isinstance(value, (list, tuple)): @@ -160,6 +181,9 @@ class WandbConfig(TrackerConfig): save_xla_dumps: bool = False """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). 
This is useful for debugging.""" + replicate_path: Optional[str] = None + """If set, write config and summary to this path (local or GCS) on finish().""" + def init(self, run_id: Optional[str]) -> WandbTracker: import wandb @@ -240,7 +264,7 @@ def init(self, run_id: Optional[str]) -> WandbTracker: wandb.summary["num_hosts"] = jax.process_count() # type: ignore wandb.summary["backend"] = jax.default_backend() # type: ignore - return WandbTracker(r) + return WandbTracker(r, replicate_path=self.replicate_path) def _git_settings(self): other_settings = dict() diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index df48b694c4..7b5af273b0 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -15,7 +15,7 @@ """Base infrastructure for eval metrics analysis. This module provides a base config and utilities for analysis jobs that -read eval_metrics.jsonl files from completed training runs. The subclassing +read tracker_metrics.jsonl files from completed training runs. The subclassing pattern mirrors the Evaluator approach in lib/marin/src/marin/evaluation/evaluators/evaluator.py, so specific analyses (like IsoFlop) should subclass EvalMetricsAnalysisConfig. @@ -54,11 +54,14 @@ def _backfill_metrics_from_wandb( entity_project: str, ) -> bool: """ - Backfill eval_metrics.jsonl from WandB for a training run. + Backfill tracker_metrics.jsonl from WandB for a training run. + + Writes a single record with config and summary, matching the format + written by WandbTracker.finish() when replicate_path is set. Args: checkpoint_path: Path to the checkpoint directory - metrics_file: Full path to where eval_metrics.jsonl should be written + metrics_file: Full path to where tracker_metrics.jsonl should be written entity_project: WandB entity/project (format: 'entity/project') Returns: @@ -70,30 +73,24 @@ def _backfill_metrics_from_wandb( try: run_id = extract_run_name_from_path(checkpoint_path) - logger.info(f"Attempting to backfill summary metrics for run_id: {run_id}") + logger.info(f"Attempting to backfill metrics for run_id: {run_id}") api = wandb.Api() run = api.run(f"{entity_project}/{run_id}") - # Get summary metrics only - summary = dict(run.summary) - - eval_metrics = {k: v for k, v in summary.items() if k.startswith("eval/")} - if not eval_metrics: - logger.warning(f"No eval summary metrics found in WandB for run {run_id}") - return False + # Build record matching WandbTracker._write_replicate_file format record = { - "step": summary.get("_step", summary.get("trainer/global_step", 0)), - **eval_metrics, + "config": dict(run.config), + "summary": {k: v for k, v in run.summary.items() if not k.startswith("_")}, } fs, _, _ = fsspec.get_fs_token_paths(metrics_file) fs.makedirs(os.path.dirname(metrics_file), exist_ok=True) with fs.open(metrics_file, "w") as f: - f.write(json.dumps(record) + "\n") + f.write(json.dumps(record, sort_keys=True, default=str) + "\n") - logger.info(f"Successfully backfilled summary metrics to {metrics_file}") + logger.info(f"Successfully backfilled metrics to {metrics_file}") return True except Exception as e: @@ -115,11 +112,11 @@ class EvalMetricsAnalysisConfig: output_path: str """Where to write analysis outputs.""" - metrics_filename: str = "eval_metrics.jsonl" + metrics_filename: str = "tracker_metrics.jsonl" """Name of the metrics file within each checkpoint directory.""" backfill_from_wandb: bool = True - """If True, 
backfill eval_metrics.jsonl from WandB for runs that completed before this feature.""" + """If True, backfill tracker_metrics.jsonl from WandB for runs that completed before this feature.""" wandb_entity_project: str = "marin-community/marin" """WandB entity/project to query for backfill (format: 'entity/project').""" diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 8c8d21214c..d0ed1618b0 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -731,7 +731,7 @@ def fit_scaling_laws( hidden_dim=int(meta["d"]) if meta else 0, num_layers=int(meta["L"]) if meta else 0, batch_size=int(meta["B"]) if meta else 0, - optimal_params=float(nearest_row.get("params", C / (6 * N_star))), + optimal_params=float(nearest_row.get("params") or C / (6 * N_star)), ) ) From ae288855a19475f89fc55ed100e7571fa2e285df Mon Sep 17 00:00:00 2001 From: Helw150 Date: Fri, 26 Dec 2025 10:04:54 -0800 Subject: [PATCH 26/79] Move to using metadata for all info --- .../marin/scaling_laws/isoflop_analysis.py | 64 ++++++++++++------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index d0ed1618b0..c3d89019ce 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -719,8 +719,6 @@ def fit_scaling_laws( idx = (sub.tokens - N_star).abs().argmin() nearest_row = sub.iloc[idx] - run_name = nearest_row["name"] - meta = parse_isoflop_run_name(run_name) minima_records.append( MinimaRecord( @@ -728,9 +726,9 @@ def fit_scaling_laws( flops=float(C), optimal_tokens=N_star, loss_at_optimal=loss_opt, - hidden_dim=int(meta["d"]) if meta else 0, - num_layers=int(meta["L"]) if meta else 0, - batch_size=int(meta["B"]) if meta else 0, + hidden_dim=int(nearest_row["hidden_dim"]), + num_layers=int(nearest_row["num_layers"]), + batch_size=int(nearest_row["batch_size"]), optimal_params=float(nearest_row.get("params") or C / (6 * N_star)), ) ) @@ -774,9 +772,11 @@ def transform_metrics_for_isoflop( transforms it into the format expected by the analysis: columns: tokens, loss, flops, params, name, label + The DataFrame contains nested 'config' and 'summary' dicts from tracker_metrics.jsonl. 
+ Args: df: Raw metrics DataFrame from read_metrics_dataframe() - metric_key: Which metric column to use for loss + metric_key: Which metric column to use for loss (e.g., 'eval/paloma/c4_en/bpb') label_map: Optional mapping from experiment_name -> display label Returns: @@ -785,39 +785,52 @@ def transform_metrics_for_isoflop( if df.empty: return pd.DataFrame(columns=["tokens", "loss", "flops", "params", "name", "label"]) - # Get final metrics for each run (max step) - final_metrics = df.loc[df.groupby("run_path")["step"].idxmax()].copy() - records = [] - for _, row in final_metrics.iterrows(): + for _, row in df.iterrows(): run_path = row["run_path"] run_name = extract_run_name_from_path(run_path) - meta = parse_isoflop_run_name(run_name) - if meta is None: - logger.warning(f"Could not parse metadata from run name: {run_name}") + + # Extract config and summary dicts + config = row.get("config", {}) or {} + summary = row.get("summary", {}) or {} + model_config = config.get("model", {}) or {} + trainer_config = config.get("trainer", {}) or {} + + # Get tokens directly from summary + tokens = summary.get("throughput/total_tokens") + if tokens is None or pd.isna(tokens): + logger.warning(f"Missing throughput/total_tokens in summary for run {run_name}") continue - flops = meta["flops"] - if flops < 1e18: + # Get total FLOPs from summary (convert GFLOPs to FLOPs) + total_gflops = summary.get("throughput/total_gflops") + if total_gflops is None or pd.isna(total_gflops): + logger.warning(f"Missing throughput/total_gflops in summary for run {run_name}") continue + flops = total_gflops * 1e9 - # Calculate tokens = steps * batch * seq_len - steps = row["step"] - batch = meta["B"] - tokens = steps * batch * SEQ_LEN + if flops < 1e18: + continue - # Get loss from the metric column - loss = row.get(metric_key) + # Get loss from summary[metric_key] + loss = summary.get(metric_key) if loss is None or pd.isna(loss): logger.warning(f"Missing metric {metric_key} for run {run_name}") continue - params = row.get("parameter_count") + # Get parameter count from summary + params = summary.get("parameter_count") if params is None or pd.isna(params): params = None - # Determine label - exp_name = meta["experiment_name"] + # Get model architecture from config + hidden_dim = model_config.get("hidden_dim") + num_layers = model_config.get("num_layers") + batch_size = trainer_config.get("train_batch_size") + + # Determine experiment name and label from run name + meta = parse_isoflop_run_name(run_name) + exp_name = meta["experiment_name"] if meta else run_name if label_map and exp_name in label_map: label = label_map[exp_name] else: @@ -829,6 +842,9 @@ def transform_metrics_for_isoflop( loss=loss, flops=flops, params=params, + hidden_dim=hidden_dim, + num_layers=num_layers, + batch_size=batch_size, name=run_name, label=label, ) From 66a7a307c718578c58616e341978f214c01b3f83 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Fri, 26 Dec 2025 10:18:11 -0800 Subject: [PATCH 27/79] Round counts now they are accurate --- .../exp2166_scaling_ladder_analysis.py | 2 +- .../marin/scaling_laws/isoflop_analysis.py | 26 ++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 54cc75d7b7..8470dfce84 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -28,7 +28,7 @@ Once complete, results are saved to the output path and uploaded to WandB. 
""" -from experiments.isoflop_sweep import MARIN_SCALING_SUITES, dolma3_mix, nemotron_mix +from experiments.isoflop_sweep import MARIN_SCALING_SUITES, nemotron_mix from marin.execution.executor import executor_main from marin.scaling_laws import scaling_ladder_suite diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index c3d89019ce..aad554e63b 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -212,6 +212,30 @@ def round_to_power_of_two(x: float) -> int: return 2 ** math.ceil(math.log2(x)) +def round_flops_to_bucket(flops: float) -> float: + """Round FLOP count to 1 significant figure (XeYY format). + + This ensures runs with slightly different achieved FLOPs are grouped + together for analysis when they were targeting the same budget. + + Examples: + 1.05e19 → 1e19 + 1.5e19 → 2e19 + 2.8e19 → 3e19 + 9.5e19 → 1e20 + """ + if flops <= 0: + return flops + + exponent = math.floor(math.log10(flops)) + mantissa = flops / (10**exponent) + rounded_mantissa = round(mantissa) + + if rounded_mantissa == 10: + return 1.0 * (10 ** (exponent + 1)) + return float(rounded_mantissa) * (10**exponent) + + def compute_total_flops( batch: int, num_layers: int, @@ -807,7 +831,7 @@ def transform_metrics_for_isoflop( if total_gflops is None or pd.isna(total_gflops): logger.warning(f"Missing throughput/total_gflops in summary for run {run_name}") continue - flops = total_gflops * 1e9 + flops = round_flops_to_bucket(total_gflops * 1e9) if flops < 1e18: continue From 001af3201d53cf4520b41979c78160b35e2cbb51 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Fri, 26 Dec 2025 11:16:18 -0800 Subject: [PATCH 28/79] New Tests and Merge Main --- .../marin/scaling_laws/isoflop_analysis.py | 6 +- .../src/marin/scaling_laws/scaling_ladder.py | 12 +- tests/test_scaling_laws.py | 394 ++++++++++++++++++ 3 files changed, 409 insertions(+), 3 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index aad554e63b..0c7c99ad13 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -731,7 +731,11 @@ def fit_scaling_laws( continue # Robust quadratic fit in log10(tokens) - a, b, c = robust_quad_logx(jnp.array(sub.tokens.values), jnp.array(sub.loss.values)) + # Use float64 to avoid int32 overflow for token counts > 2^31 + a, b, c = robust_quad_logx( + jnp.array(sub.tokens.values, dtype=jnp.float64), + jnp.array(sub.loss.values, dtype=jnp.float64), + ) fit_curves[(lab, C)] = (a, b, c) if a == 0: diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index f03eda2d9f..af398060f8 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -46,11 +46,13 @@ import fsspec import jmp from fray.cluster import ResourceConfig +from haliax.partitioning import ResourceAxis from levanter.checkpoint import CheckpointerConfig from levanter.data.text import LMMixtureDatasetConfig from levanter.main.train_lm import TrainLmConfig from levanter.tracker.wandb import WandbConfig from levanter.trainer import TrainerConfig +from levanter.utils.mesh import MeshConfig from marin.execution.executor import ExecutorStep, InputName, output_path_of, this_output_path from marin.processing.tokenize import get_vocab_size_for_tokenizer @@ -217,7 +219,14 @@ 
def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: save_interval=timedelta(minutes=10), keep=[dict(every=5000)], ), - replica_dcn_axis_size=-1, + mesh=MeshConfig( + # Special axes for MoEs + # TODO: this is actually bad and we should remove, but keeping for now + compute_mapping={ + "token": (ResourceAxis.REPLICA_DCN, ResourceAxis.REPLICA, ResourceAxis.DATA), + "token_repeat": (ResourceAxis.REPLICA_DCN, ResourceAxis.REPLICA, ResourceAxis.DATA), + } + ), allow_nondivisible_batch_size=True, ), train_seq_len=config.seq_len, @@ -288,7 +297,6 @@ def scaling_ladder_rung_step( fn=run_scaling_ladder_rung, config=config, description=f"Scaling ladder rung: optimal training for {target_budget:.1e} FLOPs based on IsoFLOP analysis", - pip_dependency_groups=["tokenize_train"], ) if override_output_path is not None: diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index b10f30053c..ebcade62bf 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -33,11 +33,14 @@ candidate_configs, compute_total_flops, compute_transformer_params, + fit_scaling_laws, generate_isoflop_train_args, parse_isoflop_run_name, predict_optimal_config, robust_quad_logx, + round_flops_to_bucket, round_to_power_of_two, + transform_metrics_for_isoflop, ) # --- round_to_power_of_two tests --- @@ -630,3 +633,394 @@ def test_generate_isoflop_train_args_snapshot(): assert ( actual[key] == expected[key] ), f"Config {i}: {key} mismatch: expected {expected[key]}, got {actual[key]}" + + +# --- round_flops_to_bucket tests --- + + +def test_round_flops_to_bucket_exact_values(): + """Test that exact significant figures remain unchanged.""" + assert round_flops_to_bucket(1e18) == 1e18 + assert round_flops_to_bucket(1e19) == 1e19 + assert round_flops_to_bucket(3e19) == 3e19 + assert round_flops_to_bucket(6e19) == 6e19 + assert round_flops_to_bucket(1e20) == 1e20 + + +def test_round_flops_to_bucket_rounds_to_one_significant_figure(): + """Test rounding to 1 significant figure.""" + # Small variations should round to nearest integer mantissa + assert round_flops_to_bucket(1.05e19) == 1e19 + assert round_flops_to_bucket(1.4e19) == 1e19 + assert round_flops_to_bucket(1.5e19) == 2e19 + assert round_flops_to_bucket(2.8e19) == 3e19 + assert round_flops_to_bucket(9.5e19) == 1e20 # Wraps to next power + + +def test_round_flops_to_bucket_handles_edge_cases(): + """Test edge cases for FLOP bucket rounding.""" + assert round_flops_to_bucket(0) == 0 + assert round_flops_to_bucket(-1e18) == -1e18 + # Very large values + assert round_flops_to_bucket(5.5e21) == 6e21 + + +# --- fit_scaling_laws tests --- + + +def test_fit_scaling_laws_empty_dataframe(): + """Test that fit_scaling_laws handles empty dataframe.""" + df = pd.DataFrame() + minima_records, scaling_fits, fit_curves = fit_scaling_laws(df) + assert minima_records == [] + assert scaling_fits == {} + assert fit_curves == {} + + +def test_fit_scaling_laws_single_budget(): + """Test fit_scaling_laws with data from a single FLOP budget.""" + # Create synthetic data with multiple token counts at one budget + df = pd.DataFrame( + { + "tokens": [1e9, 2e9, 4e9, 8e9, 16e9], + "loss": [2.5, 2.2, 2.0, 2.1, 2.3], # U-shaped (optimal around 4e9) + "flops": [1e18, 1e18, 1e18, 1e18, 1e18], + "params": [1e8, 1e8, 1e8, 1e8, 1e8], + "hidden_dim": [512, 512, 512, 512, 512], + "num_layers": [6, 6, 6, 6, 6], + "batch_size": [32, 32, 32, 32, 32], + "name": ["run1", "run2", "run3", "run4", "run5"], + "label": ["nemo", "nemo", "nemo", "nemo", "nemo"], + } + ) + 
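+    # For reference: per (label, flops) bucket, fit_scaling_laws fits a robust
+    # quadratic in x = log10(tokens), loss ~= a*x**2 + b*x + c, and reads the
+    # optimum off the parabola's vertex, roughly N* = 10 ** (-b / (2 * a));
+    # buckets where a == 0 are skipped, hence the U-shaped losses above.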
minima_records, scaling_fits, fit_curves = fit_scaling_laws(df) + + # Should find exactly one minimum for the single (label, budget) pair + assert len(minima_records) == 1 + rec = minima_records[0] + assert rec.label == "nemo" + assert rec.flops == 1e18 + # Optimal should be near 4e9 (the minimum loss point) + assert 1e9 < rec.optimal_tokens < 20e9 + + # With only one budget, cannot fit scaling law + assert "nemo" not in scaling_fits + + # Should have fit curve for (nemo, 1e18) + assert ("nemo", 1e18) in fit_curves + + +def test_fit_scaling_laws_multiple_budgets(): + """Test fit_scaling_laws with multiple FLOP budgets to fit scaling law.""" + # Create data with two FLOP budgets + df = pd.DataFrame( + { + "tokens": [ + # Budget 1e18 - optimal around 2e9 + 1e9, + 2e9, + 4e9, + # Budget 1e19 - optimal around 6e9 (more tokens for more compute) + 2e9, + 6e9, + 18e9, + ], + "loss": [ + 2.3, + 2.0, + 2.2, # U-shape at 1e18 + 2.0, + 1.7, + 1.9, # U-shape at 1e19 + ], + "flops": [1e18, 1e18, 1e18, 1e19, 1e19, 1e19], + "params": [1e8, 1e8, 1e8, 5e8, 5e8, 5e8], + "hidden_dim": [512, 512, 512, 1024, 1024, 1024], + "num_layers": [6, 6, 6, 12, 12, 12], + "batch_size": [32, 32, 32, 64, 64, 64], + "name": [f"run{i}" for i in range(6)], + "label": ["nemo"] * 6, + } + ) + minima_records, scaling_fits, fit_curves = fit_scaling_laws(df) + + # Should find two minima (one per budget) + assert len(minima_records) == 2 + + # Should have scaling fit for nemo + assert "nemo" in scaling_fits + alpha, _A = scaling_fits["nemo"] + + # Scaling law alpha should be positive (more compute -> more tokens) + assert 0 < alpha < 1 + + # Should have fit curves for both budgets + assert ("nemo", 1e18) in fit_curves + assert ("nemo", 1e19) in fit_curves + + +def test_fit_scaling_laws_multiple_labels(): + """Test fit_scaling_laws with multiple dataset labels.""" + df = pd.DataFrame( + { + "tokens": [1e9, 2e9, 4e9, 1e9, 2e9, 4e9], + "loss": [2.5, 2.2, 2.4, 2.3, 2.0, 2.2], + "flops": [1e18, 1e18, 1e18, 1e18, 1e18, 1e18], + "params": [1e8] * 6, + "hidden_dim": [512] * 6, + "num_layers": [6] * 6, + "batch_size": [32] * 6, + "name": [f"run{i}" for i in range(6)], + "label": ["nemo", "nemo", "nemo", "dclm", "dclm", "dclm"], + } + ) + minima_records, _scaling_fits, fit_curves = fit_scaling_laws(df) + + # Should find two minima (one per label at the single budget) + assert len(minima_records) == 2 + labels = {rec.label for rec in minima_records} + assert labels == {"nemo", "dclm"} + + # Should have fit curves for both labels + assert ("nemo", 1e18) in fit_curves + assert ("dclm", 1e18) in fit_curves + + +# --- transform_metrics_for_isoflop tests --- + + +# Sample tracker_metrics.jsonl data extracted from real runs +# Note: throughput/total_gflops is in GFLOPs, multiply by 1e9 to get FLOPs +# For 1e18 FLOPs, we need ~1e9 GFLOPs (1e9 * 1e9 = 1e18) +# Need at least 3 data points per budget to fit a quadratic +# Loss values form a U-shape in log(tokens) space for each budget +SAMPLE_METRICS_DATA = [ + # 1e18 budget - 3 runs with U-shaped loss curve + # Too few tokens: model is good but undertrained + # Just right: optimal training + # Too many tokens: model is too small + { + "run_path": "gs://marin/checkpoints/isoflop-1e+18-d1024-L11-B8-nemo-wider-depth-adapt", + "config": { + "model": {"hidden_dim": 1024, "num_layers": 11}, + "trainer": {"train_batch_size": 8}, + }, + "summary": { + "throughput/total_tokens": 1000000000, # 1B tokens (undertrained) + "throughput/total_gflops": 1000000000.0, # 1e9 GFLOPs = 1e18 FLOPs + 
"eval/paloma/c4_en/bpb": 1.25, # Higher loss - undertrained + "parameter_count": 400000000, + }, + }, + { + "run_path": "gs://marin/checkpoints/isoflop-1e+18-d768-L8-B16-nemo-wider-depth-adapt", + "config": { + "model": {"hidden_dim": 768, "num_layers": 8}, + "trainer": {"train_batch_size": 16}, + }, + "summary": { + "throughput/total_tokens": 2500000000, # 2.5B tokens (optimal) + "throughput/total_gflops": 1000000000.0, # 1e9 GFLOPs = 1e18 FLOPs + "eval/paloma/c4_en/bpb": 1.12, # Lowest loss - optimal + "parameter_count": 272513792, + }, + }, + { + "run_path": "gs://marin/checkpoints/isoflop-1e+18-d512-L6-B32-nemo-wider-depth-adapt", + "config": { + "model": {"hidden_dim": 512, "num_layers": 6}, + "trainer": {"train_batch_size": 32}, + }, + "summary": { + "throughput/total_tokens": 5000000000, # 5B tokens (overtrained/small model) + "throughput/total_gflops": 1000000000.0, # 1e9 GFLOPs = 1e18 FLOPs + "eval/paloma/c4_en/bpb": 1.18, # Higher loss - model too small + "parameter_count": 156508160, + }, + }, + # 1e19 budget - 3 runs with U-shaped loss curve (more tokens optimal) + { + "run_path": "gs://marin/checkpoints/isoflop-1e+19-d2048-L21-B16-nemo-wider-depth-adapt", + "config": { + "model": {"hidden_dim": 2048, "num_layers": 21}, + "trainer": {"train_batch_size": 16}, + }, + "summary": { + "throughput/total_tokens": 3000000000, # 3B tokens (undertrained) + "throughput/total_gflops": 10000000000.0, # 1e10 GFLOPs = 1e19 FLOPs + "eval/paloma/c4_en/bpb": 1.05, # Higher loss - undertrained + "parameter_count": 1800000000, + }, + }, + { + "run_path": "gs://marin/checkpoints/isoflop-1e+19-d1536-L16-B32-nemo-wider-depth-adapt", + "config": { + "model": {"hidden_dim": 1536, "num_layers": 16}, + "trainer": {"train_batch_size": 32}, + }, + "summary": { + "throughput/total_tokens": 8000000000, # 8B tokens (optimal) + "throughput/total_gflops": 10000000000.0, # 1e10 GFLOPs = 1e19 FLOPs + "eval/paloma/c4_en/bpb": 0.98, # Lowest loss - optimal + "parameter_count": 998036992, + }, + }, + { + "run_path": "gs://marin/checkpoints/isoflop-1e+19-d1024-L11-B64-nemo-wider-depth-adapt", + "config": { + "model": {"hidden_dim": 1024, "num_layers": 11}, + "trainer": {"train_batch_size": 64}, + }, + "summary": { + "throughput/total_tokens": 20000000000, # 20B tokens (overtrained) + "throughput/total_gflops": 10000000000.0, # 1e10 GFLOPs = 1e19 FLOPs + "eval/paloma/c4_en/bpb": 1.02, # Higher loss - model too small + "parameter_count": 400000000, + }, + }, +] + + +def test_transform_metrics_for_isoflop_basic(): + """Test basic transformation of metrics data.""" + raw_df = pd.DataFrame(SAMPLE_METRICS_DATA) + metric_key = "eval/paloma/c4_en/bpb" + + result = transform_metrics_for_isoflop(raw_df, metric_key) + + assert len(result) == 6 # 3 runs at 1e18 + 3 runs at 1e19 + assert set(result.columns) == { + "tokens", + "loss", + "flops", + "params", + "hidden_dim", + "num_layers", + "batch_size", + "name", + "label", + } + + # Check that values are extracted correctly - first row is d1024/L11 + row0 = result.iloc[0] + assert row0["tokens"] == 1000000000 # 1B tokens + assert row0["loss"] == 1.25 + assert row0["hidden_dim"] == 1024 + assert row0["num_layers"] == 11 + assert row0["batch_size"] == 8 + + +def test_transform_metrics_for_isoflop_with_label_map(): + """Test transformation with custom label mapping.""" + raw_df = pd.DataFrame(SAMPLE_METRICS_DATA) + metric_key = "eval/paloma/c4_en/bpb" + label_map = {"nemo-wider-depth-adapt": "NeMo"} + + result = transform_metrics_for_isoflop(raw_df, metric_key, label_map) + + 
assert len(result) == 6 # 3 runs at 1e18 + 3 runs at 1e19 + assert all(result["label"] == "NeMo") + + +def test_transform_metrics_for_isoflop_filters_low_flops(): + """Test that runs with < 1e18 FLOPs are filtered out.""" + raw_df = pd.DataFrame( + [ + { + "run_path": "gs://marin/checkpoints/small-run", + "config": { + "model": {"hidden_dim": 256, "num_layers": 4}, + "trainer": {"train_batch_size": 8}, + }, + "summary": { + "throughput/total_tokens": 1e7, + "throughput/total_gflops": 1e6, # Only 1e15 FLOPs (< 1e18) + "eval/paloma/c4_en/bpb": 3.0, + "parameter_count": 1e7, + }, + } + ] + ) + + result = transform_metrics_for_isoflop(raw_df, "eval/paloma/c4_en/bpb") + assert len(result) == 0 + + +def test_transform_metrics_for_isoflop_empty_input(): + """Test transformation with empty input.""" + raw_df = pd.DataFrame() + result = transform_metrics_for_isoflop(raw_df, "eval/paloma/c4_en/bpb") + assert result.empty + + +def test_transform_metrics_for_isoflop_missing_fields(): + """Test transformation handles missing fields gracefully.""" + raw_df = pd.DataFrame( + [ + { + "run_path": "gs://marin/checkpoints/isoflop-1e+18-d512-L6-B32-incomplete", + "config": {"model": {}, "trainer": {}}, + "summary": { + # Missing throughput/total_tokens + "throughput/total_gflops": 1000001.0, + "eval/paloma/c4_en/bpb": 1.5, + }, + } + ] + ) + + result = transform_metrics_for_isoflop(raw_df, "eval/paloma/c4_en/bpb") + # Should skip the row with missing required fields + assert len(result) == 0 + + +# --- Integration test: fit_scaling_laws with transform_metrics_for_isoflop --- + + +def test_end_to_end_analysis_pipeline(): + """Integration test: transform metrics and fit scaling laws.""" + raw_df = pd.DataFrame(SAMPLE_METRICS_DATA) + + # Transform metrics + isoflop_df = transform_metrics_for_isoflop(raw_df, "eval/paloma/c4_en/bpb") + assert len(isoflop_df) == 6 # 3 runs at 1e18 + 3 runs at 1e19 + + # Fit scaling laws + minima_records, scaling_fits, _fit_curves = fit_scaling_laws(isoflop_df) + + # With 2 budgets (1e18, 1e19), each with 3 points, we should get 2 minima + assert len(minima_records) == 2 + + # Should have a scaling fit for the label + assert len(scaling_fits) == 1 + label = next(iter(scaling_fits.keys())) + alpha, A = scaling_fits[label] + + # Sanity check the scaling law parameters + assert 0 < alpha < 1 # Typical range for token scaling exponent + assert A > 0 + + +def test_minima_records_have_scaling_fit_params(): + """Test that minima records are augmented with scaling fit parameters.""" + df = pd.DataFrame( + { + "tokens": [1e9, 2e9, 4e9, 2e9, 6e9, 18e9], + "loss": [2.3, 2.0, 2.2, 2.0, 1.7, 1.9], + "flops": [1e18, 1e18, 1e18, 1e19, 1e19, 1e19], + "params": [1e8, 1e8, 1e8, 5e8, 5e8, 5e8], + "hidden_dim": [512, 512, 512, 1024, 1024, 1024], + "num_layers": [6, 6, 6, 12, 12, 12], + "batch_size": [32, 32, 32, 64, 64, 64], + "name": [f"run{i}" for i in range(6)], + "label": ["nemo"] * 6, + } + ) + minima_records, scaling_fits, _ = fit_scaling_laws(df) + + # All records for a label with a scaling fit should have the params + for rec in minima_records: + if rec.label in scaling_fits: + alpha, A = scaling_fits[rec.label] + assert rec.scaling_alpha == alpha + assert rec.scaling_A == A From c9349a65723effd33dbd9777aebd55e1ab7542c9 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Fri, 26 Dec 2025 11:34:39 -0800 Subject: [PATCH 29/79] Range Fix --- lib/marin/src/marin/scaling_laws/isoflop_analysis.py | 10 ++++++---- lib/marin/src/marin/scaling_laws/scaling_plots.py | 8 ++++---- 2 files changed, 10 
insertions(+), 8 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 0c7c99ad13..e79ce5fb1d 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -732,11 +732,13 @@ def fit_scaling_laws( # Robust quadratic fit in log10(tokens) # Use float64 to avoid int32 overflow for token counts > 2^31 + tokens_array = jnp.array(sub.tokens.values, dtype=jnp.float64) a, b, c = robust_quad_logx( - jnp.array(sub.tokens.values, dtype=jnp.float64), + tokens_array, jnp.array(sub.loss.values, dtype=jnp.float64), ) - fit_curves[(lab, C)] = (a, b, c) + # Store coefficients along with token range used for fitting + fit_curves[(lab, C)] = (a, b, c, float(tokens_array.min()), float(tokens_array.max())) if a == 0: continue @@ -991,8 +993,8 @@ class IsoFlopAnalysisResult: minima_records: list[MinimaRecord] """Raw minima records with detailed info for each optimum.""" - fit_curves: dict[tuple[str, float], tuple[float, float, float]] - """Quadratic fit coefficients {(label, flops): (a, b, c)} for plotting.""" + fit_curves: dict[tuple[str, float], tuple[float, float, float, float, float]] + """Quadratic fit coefficients {(label, flops): (a, b, c, token_min, token_max)} for plotting.""" def to_json_dict(self) -> dict: """Convert result to JSON-serializable dict (excludes DataFrame and fit_curves).""" diff --git a/lib/marin/src/marin/scaling_laws/scaling_plots.py b/lib/marin/src/marin/scaling_laws/scaling_plots.py index ec8512f0d5..66590101d4 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_plots.py +++ b/lib/marin/src/marin/scaling_laws/scaling_plots.py @@ -87,14 +87,14 @@ def create_isoflop_plot( df: pd.DataFrame, minima_records: list, - fit_curves: dict[tuple[str, float], tuple[float, float, float]], + fit_curves: dict[tuple[str, float], tuple[float, float, float, float, float]], ) -> go.Figure: """Create the IsoFLOP plot showing loss vs tokens for each compute budget. 
Args: df: DataFrame with columns: tokens, loss, flops, params, name, label minima_records: List of dicts with optimal config info per (label, flops) - fit_curves: Dict of {(label, flops): (a, b, c)} quadratic fit coefficients + fit_curves: Dict of {(label, flops): (a, b, c, token_min, token_max)} quadratic fit coefficients Returns: Plotly Figure with the isoflop visualization @@ -140,9 +140,9 @@ def create_isoflop_plot( # Draw fit curve if available key = (lab, C) if key in fit_curves: - a, b, c = fit_curves[key] + a, b, c, token_min, token_max = fit_curves[key] if a != 0: - Ls = jnp.linspace(jnp.log10(sub.tokens.min()), jnp.log10(sub.tokens.max()), 200) + Ls = jnp.linspace(jnp.log10(token_min), jnp.log10(token_max), 200) fig.add_trace( go.Scatter( x=10**Ls, From c9d3039f32a841b81e8c266e6fd52d97fb7134f8 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Fri, 26 Dec 2025 11:40:38 -0800 Subject: [PATCH 30/79] Lint --- lib/marin/src/marin/scaling_laws/isoflop_analysis.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index e79ce5fb1d..9992113725 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -700,7 +700,7 @@ def fit_scaling_laws( ) -> tuple[ list[MinimaRecord], dict[str, tuple[float, float]], - dict[tuple[str, float], tuple[float, float, float]], + dict[tuple[str, float], tuple[float, float, float, float, float]], ]: """ Fit scaling laws and extract optimal configurations. @@ -711,7 +711,8 @@ def fit_scaling_laws( Returns: - minima_records: List of dicts with optimal config info per (label, flops) - scaling_fits: Dict of {label: (alpha, A)} for N* ~ A * C^alpha - - fit_curves: Dict of {(label, flops): (a, b, c)} quadratic coefficients for plotting + - fit_curves: Dict of {(label, flops): (a, b, c, token_min, token_max)} quadratic coefficients + for plotting """ if df is None or df.empty: return [], {}, {} @@ -721,7 +722,7 @@ def fit_scaling_laws( buckets = sorted(df.flops.unique()) minima_records: list[MinimaRecord] = [] - fit_curves: dict[tuple[str, float], tuple[float, float, float]] = {} + fit_curves: dict[tuple[str, float], tuple[float, float, float, float, float]] = {} # Fit quadratic for each (label, budget) and find minima for lab in datasets: From 06c3e5e70e78d5f154e44e39a9bc879bedc8535e Mon Sep 17 00:00:00 2001 From: Helw150 Date: Fri, 26 Dec 2025 11:55:35 -0800 Subject: [PATCH 31/79] Lint --- tests/test_scaling_laws.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index ebcade62bf..0d55a872fa 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -478,7 +478,7 @@ def test_create_isoflop_plot_with_data(): optimal_params=1e8, ) ] - fit_curves = {("nemo", 1e18): (0.1, -1.0, 3.0)} + fit_curves = {("nemo", 1e18): (0.1, -1.0, 3.0, 1e9, 3e9)} fig = create_isoflop_plot(df, minima_records, fit_curves) assert fig is not None From cb8e36f611b99aa2a3d390a636d96053c0ace26b Mon Sep 17 00:00:00 2001 From: Helw150 Date: Fri, 26 Dec 2025 12:42:46 -0800 Subject: [PATCH 32/79] Claude Review Comment Fixes --- experiments/exp2166_scaling_ladder_analysis.py | 6 ++---- .../src/marin/scaling_laws/eval_metrics_reader.py | 4 +++- lib/marin/src/marin/scaling_laws/isoflop_analysis.py | 11 ++++++----- lib/marin/src/marin/scaling_laws/scaling_ladder.py | 5 +++-- lib/marin/src/marin/scaling_laws/scaling_plots.py 
| 6 ++++-- 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 8470dfce84..f28545e6ca 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Exp2166: Scaling Ladder Analysis for Nemotron and Dolma3. +"""Exp2166: Scaling Ladder Analysis for Nemotron. This experiment runs scaling ladder analysis on the isoflop training sweeps -for two datasets: -- Nemotron (nemo-wider-depth-adapt) -- Dolma3 (dolma3-mix-150b-1025) +for the Nemotron (nemo-wider-depth-adapt) dataset. The scaling ladder: 1. Fits scaling laws from IsoFLOP sweep data to find compute-optimal configurations diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index 7b5af273b0..17b92d232e 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -30,6 +30,8 @@ import fsspec import pandas as pd +from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT + try: import wandb @@ -118,7 +120,7 @@ class EvalMetricsAnalysisConfig: backfill_from_wandb: bool = True """If True, backfill tracker_metrics.jsonl from WandB for runs that completed before this feature.""" - wandb_entity_project: str = "marin-community/marin" + wandb_entity_project: str = f"{WANDB_ENTITY}/{WANDB_PROJECT}" """WandB entity/project to query for backfill (format: 'entity/project').""" diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 9992113725..c26411c03d 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -56,6 +56,7 @@ extract_run_name_from_path, read_metrics_dataframe, ) +from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT logger = logging.getLogger(__name__) @@ -1025,10 +1026,10 @@ class IsoFlopAnalysisConfig(EvalMetricsAnalysisConfig): upload_to_wandb: bool = True """Whether to upload plots to WandB.""" - wandb_entity: str = "marin-community" - """WandB entity for uploads.""" + wandb_entity: str = WANDB_ENTITY + """WandB entity for uploads (defaults to WANDB_ENTITY env var or 'marin-community').""" - wandb_project: str = "marin-analysis" + wandb_project: str = f"{WANDB_PROJECT}-analysis" """WandB project for uploads.""" wandb_run_name: str = "scaling-ladder-analysis" @@ -1111,8 +1112,8 @@ def isoflop_analysis_step( label_map: dict[str, str] | None = None, save_plots: bool = True, upload_to_wandb: bool = True, - wandb_entity: str = "marin-community", - wandb_project: str = "marin-analysis", + wandb_entity: str = WANDB_ENTITY, + wandb_project: str = f"{WANDB_PROJECT}-analysis", wandb_run_name: str | None = None, ) -> ExecutorStep: """Create an ExecutorStep for scaling ladder analysis. 
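The marin.utilities.wandb_utils module imported above is not included in this patch
series. Based on the docstring in IsoFlopAnalysisConfig ("defaults to WANDB_ENTITY env
var or 'marin-community'") and the f"{WANDB_ENTITY}/{WANDB_PROJECT}" default that replaces
"marin-community/marin" in eval_metrics_reader.py, it is assumed to expose two
module-level constants roughly like the sketch below; this is a minimal sketch under
those assumptions, not the actual module.

    # Hypothetical sketch of marin/utilities/wandb_utils.py -- assumed, not part of
    # this patch series. It centralizes the defaults that the hard-coded
    # "marin-community" / "marin-analysis" strings are being replaced with.
    import os

    WANDB_ENTITY: str = os.environ.get("WANDB_ENTITY", "marin-community")
    WANDB_PROJECT: str = os.environ.get("WANDB_PROJECT", "marin")

    if __name__ == "__main__":
        # Callers derive per-purpose defaults from these constants, e.g. the
        # analysis project used for plot uploads elsewhere in this patch:
        print(f"{WANDB_ENTITY}/{WANDB_PROJECT}-analysis")
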
diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index af398060f8..897163a3a5 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -67,6 +67,7 @@ predict_optimal_config, ) from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm +from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT logger = logging.getLogger(__name__) @@ -340,8 +341,8 @@ def scaling_ladder_suite( label_map: dict[str, str] | None = None, save_plots: bool = True, upload_to_wandb: bool = True, - wandb_entity: str = "marin-community", - wandb_project: str = "marin-analysis", + wandb_entity: str = WANDB_ENTITY, + wandb_project: str = f"{WANDB_PROJECT}-analysis", ) -> ScalingLadderSuite: """Create a complete scaling ladder: IsoFLOP analysis + optimal training runs. diff --git a/lib/marin/src/marin/scaling_laws/scaling_plots.py b/lib/marin/src/marin/scaling_laws/scaling_plots.py index 66590101d4..3f0418e52e 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_plots.py +++ b/lib/marin/src/marin/scaling_laws/scaling_plots.py @@ -28,6 +28,8 @@ import plotly.graph_objects as go import plotly.io as pio +from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT + try: import wandb @@ -297,8 +299,8 @@ def save_plots( def upload_plots_to_wandb( fig_isoflop: go.Figure, fig_scaling: go.Figure, - entity: str = "marin-community", - project: str = "marin-analysis", + entity: str = WANDB_ENTITY, + project: str = f"{WANDB_PROJECT}-analysis", run_name: str = "scaling-ladder-analysis", ) -> None: """Upload plots to Weights & Biases. From c395e85f53966f8b23baa39fd8530941007bfe5c Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 5 Jan 2026 15:40:41 -0800 Subject: [PATCH 33/79] Top Level Code for Validation Sets --- .../exp2166_scaling_ladder_analysis.py | 4 ++- .../src/marin/scaling_laws/scaling_ladder.py | 26 +++++++++---------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index f28545e6ca..cf26271a63 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -26,6 +26,7 @@ Once complete, results are saved to the output path and uploaded to WandB. """ +from experiments.defaults import default_validation_sets from experiments.isoflop_sweep import MARIN_SCALING_SUITES, nemotron_mix from marin.execution.executor import executor_main from marin.scaling_laws import scaling_ladder_suite @@ -42,12 +43,13 @@ nemotron_suite = scaling_ladder_suite( - name="exp2166-scaling-ladder-nemotron", + name="exp2166-scaling-ladder-nemotron-validation", training_runs=nemotron_training, target_budgets=TARGET_BUDGETS, label="nemo-wider-depth-adapt", tokenized=nemotron_mix, wandb_project="marin-analysis", + validation_sets=default_validation_sets(tokenizer="stanford-crfm/marin-tokenizer"), ) all_steps = [*nemotron_suite.all_steps] diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 897163a3a5..81f0943bb2 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -115,10 +115,6 @@ class ScalingLadderRungConfig: This config references an IsoFLOP analysis step and specifies the target compute budget. At runtime, the optimal config is loaded from the analysis output. 
- - Note: If you need validation sets, pass an LMMixtureDatasetConfig with - validation sets already configured. This module does not handle default - validation sets to avoid experiment-specific dependencies. """ analysis_output_path: str @@ -131,8 +127,7 @@ class ScalingLadderRungConfig: """Dataset label to use for scaling fit (e.g., 'nemo', 'comma', 'dclm').""" tokenized: InputName | str | LMMixtureDatasetConfig - """Tokenized dataset for training. Can be a path, InputName, or LMMixtureDatasetConfig. - If validation sets are needed, pass an LMMixtureDatasetConfig with them pre-configured.""" + """Tokenized dataset for training. Can be a path, InputName, or LMMixtureDatasetConfig.""" output_path: str """Where to write training outputs.""" @@ -146,6 +141,9 @@ class ScalingLadderRungConfig: sweep_config: IsoFlopSweepConfig | None = None """Optional sweep config for predict_optimal_config. Uses defaults if None.""" + validation_sets: dict[str, TokenizerStep] | None = None + """Optional validation sets to add for eval loss tracking.""" + def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: """Run one rung of the scaling ladder (one compute-optimal training run).""" @@ -197,9 +195,7 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: optimizer_cfg = build_optimizer_config(candidate) - # Accepts both string paths and LMMixtureDatasetConfig. - # If validation sets are needed, they should be pre-configured in the LMMixtureDatasetConfig. - pretraining_data = _prepare_data_config(config.tokenized) + pretraining_data = _prepare_data_config(config.tokenized, config.validation_sets) train_config = TrainLmConfig( data=pretraining_data, @@ -253,6 +249,7 @@ def scaling_ladder_rung_step( tokenizer: str = "stanford-crfm/marin-tokenizer", seq_len: int = 4096, override_output_path: str | None = None, + validation_sets: dict[str, TokenizerStep] | None = None, ) -> ExecutorStep: """Create an ExecutorStep for one rung of the scaling ladder. @@ -265,11 +262,11 @@ def scaling_ladder_rung_step( target_budget: Target compute budget in FLOPs label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') tokenized: Tokenized dataset to train on. Can be an ExecutorStep, InputName, - or LMMixtureDatasetConfig. If validation sets are needed, pass an - LMMixtureDatasetConfig with them pre-configured. + or LMMixtureDatasetConfig. tokenizer: Tokenizer to use seq_len: Sequence length for training override_output_path: Optional override for the output path + validation_sets: Optional validation sets for eval loss tracking Returns: ExecutorStep configured to run one optimal training run @@ -291,6 +288,7 @@ def scaling_ladder_rung_step( output_path=output_path, tokenizer=tokenizer, seq_len=seq_len, + validation_sets=validation_sets, ) step = ExecutorStep( @@ -343,6 +341,7 @@ def scaling_ladder_suite( upload_to_wandb: bool = True, wandb_entity: str = WANDB_ENTITY, wandb_project: str = f"{WANDB_PROJECT}-analysis", + validation_sets: dict[str, TokenizerStep] | None = None, ) -> ScalingLadderSuite: """Create a complete scaling ladder: IsoFLOP analysis + optimal training runs. @@ -359,8 +358,7 @@ def scaling_ladder_suite( target_budgets: Target compute budgets (in FLOPs) for optimal training label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') tokenized: Tokenized dataset for optimal training runs. Can be an ExecutorStep, - InputName, or LMMixtureDatasetConfig. If validation sets are needed, - pass an LMMixtureDatasetConfig with them pre-configured. 
+ InputName, or LMMixtureDatasetConfig. tokenizer: Tokenizer to use seq_len: Sequence length for training metric_key: Which metric to use for loss @@ -369,6 +367,7 @@ def scaling_ladder_suite( upload_to_wandb: Whether to upload plots to WandB wandb_entity: WandB entity for uploads wandb_project: WandB project for uploads + validation_sets: Optional validation sets for eval loss tracking Returns: ScalingLadderSuite containing the analysis step and optimal training steps @@ -405,6 +404,7 @@ def scaling_ladder_suite( tokenized=tokenized, tokenizer=tokenizer, seq_len=seq_len, + validation_sets=validation_sets, ) optimal_runs.append(run_step) From 4691d5eb37e5f71055abac367369b962bacb66fa Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 5 Jan 2026 15:57:18 -0800 Subject: [PATCH 34/79] Focus tests in response to Russell PR comments --- tests/test_scaling_laws.py | 790 ++++++------------------------------- 1 file changed, 119 insertions(+), 671 deletions(-) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 0d55a872fa..cdfd4670d9 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -12,99 +12,82 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for the scaling_laws module.""" +"""Unit tests for the scaling_laws module. + +These tests focus on integration and end-to-end validation with specific expected outputs, +particularly the snapshot test which ensures reproducibility of config generation. +""" import jax.numpy as jnp import pandas as pd +import pytest from marin.scaling_laws.isoflop_analysis import ( - BETA2_BASE, - BETA2_BATCH_DIVISOR, - DEFAULT_BUDGETS, - HIDDEN_HEAD_RATIO, - LR_CONSTANT, MARIN_TOKENIZER_VOCAB_SIZE, - MLP_RATIO, - SEQ_LEN, - CandidateConfig, IsoFlopSweepConfig, IsoFlopTrainArgs, - MinimaRecord, candidate_configs, compute_total_flops, - compute_transformer_params, fit_scaling_laws, generate_isoflop_train_args, parse_isoflop_run_name, - predict_optimal_config, robust_quad_logx, round_flops_to_bucket, round_to_power_of_two, transform_metrics_for_isoflop, ) -# --- round_to_power_of_two tests --- - - -def test_round_to_power_of_two_exact_powers(): - """Test that exact powers of two are unchanged.""" - assert round_to_power_of_two(1) == 1 - assert round_to_power_of_two(2) == 2 - assert round_to_power_of_two(4) == 4 - assert round_to_power_of_two(8) == 8 - assert round_to_power_of_two(16) == 16 - - -def test_round_to_power_of_two_rounds_up(): - """Test that non-powers round up to nearest power of two.""" - assert round_to_power_of_two(3) == 4 - assert round_to_power_of_two(5) == 8 - assert round_to_power_of_two(7) == 8 - assert round_to_power_of_two(9) == 16 - - -def test_round_to_power_of_two_small_values(): - """Test that small/zero values become 1.""" - assert round_to_power_of_two(0.5) == 1 - assert round_to_power_of_two(0.1) == 1 - assert round_to_power_of_two(0) == 1 - - -def test_round_to_power_of_two_large_values(): - """Test rounding for large values.""" - assert round_to_power_of_two(100) == 128 - assert round_to_power_of_two(1000) == 1024 - assert round_to_power_of_two(1025) == 2048 - - -# --- compute_total_flops tests --- +# --- Utility function tests (parametrized) --- + + +@pytest.mark.parametrize( + "value,expected", + [ + # Exact powers unchanged + (1, 1), + (2, 2), + (4, 4), + (16, 16), + # Non-powers round up + (3, 4), + (5, 8), + (9, 16), + # Small/zero values become 1 + (0.5, 1), + (0, 1), + # Large values + (100, 128), + (1000, 1024), + ], +) 
+def test_round_to_power_of_two(value, expected): + """Test round_to_power_of_two produces correct results.""" + assert round_to_power_of_two(value) == expected + + +@pytest.mark.parametrize( + "value,expected", + [ + # Exact values unchanged + (1e18, 1e18), + (1e19, 1e19), + (3e19, 3e19), + # Rounds to 1 significant figure + (1.05e19, 1e19), + (1.4e19, 1e19), + (1.5e19, 2e19), + (2.8e19, 3e19), + (9.5e19, 1e20), + # Edge cases + (0, 0), + ], +) +def test_round_flops_to_bucket(value, expected): + """Test round_flops_to_bucket rounds to 1 significant figure.""" + assert round_flops_to_bucket(value) == expected -def test_compute_total_flops_larger_model_uses_more_flops(): - """Test that larger models use more FLOPs.""" - small_flops = compute_total_flops( - batch=32, - num_layers=12, - hidden=512, - intermediate=2048, - num_kv_heads=8, - num_heads=8, - steps=1000, - seq_len=4096, - vocab_size=128256, - ) - large_flops = compute_total_flops( - batch=32, - num_layers=24, - hidden=1024, - intermediate=4096, - num_kv_heads=16, - num_heads=16, - steps=1000, - seq_len=4096, - vocab_size=128256, - ) - assert large_flops > small_flops +# --- FLOP computation tests --- def test_compute_total_flops_linear_in_batch_and_steps(): @@ -146,11 +129,12 @@ def test_compute_total_flops_linear_in_batch_and_steps(): assert abs(double_steps_flops - 2 * base_flops) / base_flops < 0.01 -# --- parse_isoflop_run_name tests --- +# --- Run name parsing tests --- -def test_parse_isoflop_run_name_basic(): - """Test parsing a standard isoflop run name.""" +def test_parse_isoflop_run_name(): + """Test parsing isoflop run names extracts correct values.""" + # Standard name result = parse_isoflop_run_name("isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt") assert result is not None assert result["flops"] == 1e19 @@ -159,37 +143,22 @@ def test_parse_isoflop_run_name_basic(): assert result["B"] == 1024 assert result["experiment_name"] == "nemo-wider-depth-adapt" - -def test_parse_isoflop_run_name_with_hash_suffix(): - """Test parsing run name with hash suffix.""" + # With hash suffix result = parse_isoflop_run_name("isoflop-1e+18-d512-L8-B128-dclm-a1b2c3") assert result is not None assert result["flops"] == 1e18 - assert result["d"] == 512 - assert result["L"] == 8 - assert result["B"] == 128 assert result["experiment_name"] == "dclm" - -def test_parse_isoflop_run_name_invalid_format(): - """Test that invalid formats return None.""" + # Invalid formats return None assert parse_isoflop_run_name("not-a-valid-name") is None - assert parse_isoflop_run_name("isoflop-missing-parts") is None assert parse_isoflop_run_name("") is None -# --- candidate_configs tests --- - - -def test_candidate_configs_generates_candidates(): - """Test that candidate_configs generates at least one config.""" - cfg = IsoFlopSweepConfig() - candidates = list(candidate_configs(cfg, 1e19, MARIN_TOKENIZER_VOCAB_SIZE)) - assert len(candidates) > 0 +# --- Candidate config tests --- def test_candidate_configs_within_tolerance(): - """Test that generated configs are within FLOP tolerance.""" + """Test that generated configs achieve the target FLOP budget within tolerance.""" cfg = IsoFlopSweepConfig(flop_tolerance=0.01) budget = 1e19 for candidate in candidate_configs(cfg, budget, MARIN_TOKENIZER_VOCAB_SIZE): @@ -208,33 +177,14 @@ def test_candidate_configs_within_tolerance(): assert relative_error <= cfg.flop_tolerance -def test_candidate_configs_fields_populated(): - """Test that all candidate fields are properly populated.""" - cfg = IsoFlopSweepConfig() - 
candidates = list(candidate_configs(cfg, 1e19, MARIN_TOKENIZER_VOCAB_SIZE)) - assert len(candidates) > 0 - - for candidate in candidates: - assert candidate.hidden_size > 0 - assert candidate.intermediate_dim == candidate.hidden_size * MLP_RATIO - assert candidate.num_layers > 0 - assert candidate.num_heads > 0 - assert candidate.num_kv_heads > 0 - assert candidate.batch_size >= 8 - assert candidate.train_steps > 0 - assert candidate.learning_rate > 0 - assert 0 < candidate.beta2 < 1 - assert candidate.tokens > 0 - assert candidate.flops_budget == 1e19 - - -# --- robust_quad_logx tests --- +# --- Curve fitting tests --- def test_robust_quad_logx_fits_quadratic(): - """Test that robust_quad_logx recovers known coefficients.""" + """Test that robust_quad_logx recovers known coefficients from synthetic data.""" x = jnp.array([1e9, 1e10, 1e11, 1e12]) L = jnp.log10(x) + # y = 0.1 * L^2 - 2 * L + 20 y = 0.1 * L**2 - 2 * L + 20 a, b, c = robust_quad_logx(x, y) @@ -244,289 +194,10 @@ def test_robust_quad_logx_fits_quadratic(): assert abs(c - 20) < 0.5 -def test_robust_quad_logx_handles_noise(): - """Test that robust_quad_logx handles noisy data.""" - x = jnp.array([1e9, 1e10, 1e11, 1e12, 1e13]) - L = jnp.log10(x) - y_clean = 0.05 * L**2 - 1.5 * L + 15 - noise = jnp.array([0.01, -0.02, 0.015, -0.01, 0.005]) - y = y_clean + noise - - a, b, _ = robust_quad_logx(x, y) - - assert abs(a - 0.05) < 0.05 - assert abs(b - (-1.5)) < 0.5 - - -# --- predict_optimal_config tests --- - - -def test_predict_optimal_config_unknown_label_returns_none(): - """Test that unknown labels return None.""" - scaling_fits = {"nemo": (0.5, 1e5)} - result = predict_optimal_config( - scaling_fits=scaling_fits, - target_flops=1e21, - label="unknown", - ) - assert result is None - - -def test_predict_optimal_config_valid_label(): - """Test prediction with a valid label.""" - scaling_fits = {"nemo": (0.5, 1e5)} - result = predict_optimal_config( - scaling_fits=scaling_fits, - target_flops=1e20, - label="nemo", - ) - assert result is None or isinstance(result, CandidateConfig) - - -# --- Constants tests --- - - -def test_constants_default_budgets(): - """Test that DEFAULT_BUDGETS is valid.""" - assert len(DEFAULT_BUDGETS) > 0 - assert all(b > 0 for b in DEFAULT_BUDGETS) - assert list(DEFAULT_BUDGETS) == sorted(DEFAULT_BUDGETS) - - -def test_constants_have_expected_values(): - """Test that constants have expected values.""" - assert SEQ_LEN == 4096 - assert MARIN_TOKENIZER_VOCAB_SIZE == 128256 - assert LR_CONSTANT == 0.33 - assert HIDDEN_HEAD_RATIO == 128 - assert BETA2_BASE == 0.98 - assert BETA2_BATCH_DIVISOR == 128 - assert MLP_RATIO == 4 - - -# --- compute_transformer_params tests --- - - -def test_compute_transformer_params_returns_positive_int(): - """Test that compute_transformer_params returns a positive integer.""" - params = compute_transformer_params( - hidden_dim=512, - intermediate_dim=2048, - num_layers=12, - vocab_size=128256, - num_kv_heads=8, - num_heads=8, - ) - assert params > 0 - assert isinstance(params, int) - - -def test_compute_transformer_params_scales_with_hidden_dim(): - """Test that params scale with hidden dimension.""" - small = compute_transformer_params( - hidden_dim=512, - intermediate_dim=2048, - num_layers=12, - vocab_size=128256, - ) - large = compute_transformer_params( - hidden_dim=1024, - intermediate_dim=4096, - num_layers=12, - vocab_size=128256, - ) - assert large > small - - -def test_compute_transformer_params_scales_with_layers(): - """Test that params scale with number of layers.""" - 
shallow = compute_transformer_params( - hidden_dim=512, - intermediate_dim=2048, - num_layers=6, - vocab_size=128256, - ) - deep = compute_transformer_params( - hidden_dim=512, - intermediate_dim=2048, - num_layers=12, - vocab_size=128256, - ) - assert deep > shallow - - -# --- generate_isoflop_train_args tests --- - - -def test_generate_isoflop_train_args_returns_list(): - """Test that generate_isoflop_train_args returns a list of IsoFlopTrainArgs.""" - config = IsoFlopSweepConfig(budgets=(1e18,)) - result = generate_isoflop_train_args( - sweep_config=config, - experiment_name="test-experiment", - vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, - ) - assert isinstance(result, list) - assert len(result) > 0 - assert all(isinstance(arg, IsoFlopTrainArgs) for arg in result) - - -def test_generate_isoflop_train_args_populates_fields(): - """Test that all required fields are populated.""" - config = IsoFlopSweepConfig(budgets=(1e18,)) - result = generate_isoflop_train_args( - sweep_config=config, - experiment_name="test-experiment", - vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, - ) - for args in result: - assert args.candidate is not None - assert args.candidate.hidden_size > 0 - assert args.candidate.num_layers > 0 - - assert args.model_config is not None - assert args.model_config.hidden_dim == args.candidate.hidden_size - assert args.model_config.num_layers == args.candidate.num_layers - - assert args.optimizer_config is not None - assert args.optimizer_config.learning_rate == args.candidate.learning_rate - - assert args.tpu_type.startswith("v5p-") - assert "isoflop" in args.run_name - assert "test-experiment" in args.run_name - assert len(args.tags) > 0 - assert args.output_path.startswith("checkpoints/isoflop/") - - -def test_generate_isoflop_train_args_more_budgets_more_configs(): - """Test that more budgets produce more configs.""" - config_single = IsoFlopSweepConfig(budgets=(1e18,)) - config_multi = IsoFlopSweepConfig(budgets=(1e18, 1e19)) - - result_single = generate_isoflop_train_args( - sweep_config=config_single, - experiment_name="test", - vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, - ) - result_multi = generate_isoflop_train_args( - sweep_config=config_multi, - experiment_name="test", - vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, - ) - assert len(result_multi) > len(result_single) - - -def test_generate_isoflop_train_args_unique_run_names(): - """Test that all run names are unique.""" - config = IsoFlopSweepConfig(budgets=(1e18, 1e19)) - result = generate_isoflop_train_args( - sweep_config=config, - experiment_name="test", - vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, - ) - run_names = [args.run_name for args in result] - assert len(run_names) == len(set(run_names)) - - -def test_generate_isoflop_train_args_includes_experiment_name(): - """Test that experiment name appears in run names.""" - config = IsoFlopSweepConfig(budgets=(1e18,)) - result = generate_isoflop_train_args( - sweep_config=config, - experiment_name="my-custom-experiment", - vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, - ) - for args in result: - assert "my-custom-experiment" in args.run_name - - -# --- Plotting tests --- - - -def test_create_isoflop_plot_empty_data(): - """Test that create_isoflop_plot handles empty data.""" - from marin.scaling_laws import create_isoflop_plot - - df = pd.DataFrame() - fig = create_isoflop_plot(df, [], {}) - assert fig is not None - - -def test_create_isoflop_plot_with_data(): - """Test create_isoflop_plot with sample data.""" - from marin.scaling_laws import create_isoflop_plot - - df = pd.DataFrame( - { - 
"tokens": [1e9, 2e9, 3e9], - "loss": [2.5, 2.3, 2.2], - "flops": [1e18, 1e18, 1e18], - "params": [1e8, 1e8, 1e8], - "name": ["run1", "run2", "run3"], - "label": ["nemo", "nemo", "nemo"], - } - ) - minima_records = [ - MinimaRecord( - label="nemo", - flops=1e18, - optimal_tokens=2e9, - loss_at_optimal=2.3, - hidden_dim=512, - num_layers=8, - batch_size=64, - optimal_params=1e8, - ) - ] - fit_curves = {("nemo", 1e18): (0.1, -1.0, 3.0, 1e9, 3e9)} - fig = create_isoflop_plot(df, minima_records, fit_curves) - assert fig is not None - - -def test_create_scaling_plot_empty(): - """Test that create_scaling_plot handles empty data.""" - from marin.scaling_laws import create_scaling_plot - - fig = create_scaling_plot([], {}) - assert fig is not None - - -def test_create_scaling_plot_with_data(): - """Test create_scaling_plot with sample data.""" - from marin.scaling_laws import create_scaling_plot - - minima_records = [ - MinimaRecord( - label="nemo", - flops=1e18, - optimal_tokens=1e9, - loss_at_optimal=2.3, - hidden_dim=512, - num_layers=8, - batch_size=64, - optimal_params=1e8, - ), - MinimaRecord( - label="nemo", - flops=1e19, - optimal_tokens=5e9, - loss_at_optimal=2.1, - hidden_dim=1024, - num_layers=16, - batch_size=128, - optimal_params=5e8, - ), - ] - scaling_fits = {"nemo": (0.5, 1e5)} - fig = create_scaling_plot(minima_records, scaling_fits) - assert fig is not None - - -# --- Snapshot tests --- +# --- Snapshot test for config generation --- # Snapshot of expected output for generate_isoflop_train_args with budget=1e18. -# This captures the configuration generation logic from experiments/isoflop_sweep.py -# on the main branch to ensure the refactored code produces identical configs. +# This captures the configuration generation logic to ensure reproducibility. EXPECTED_ISOFLOP_CONFIGS_1E18 = [ { "hidden_size": 512, @@ -599,8 +270,8 @@ def test_create_scaling_plot_with_data(): def test_generate_isoflop_train_args_snapshot(): """Snapshot test: verify generate_isoflop_train_args produces expected configs. - This test ensures the refactored scaling_laws module produces identical - configurations to the original experiments/isoflop_sweep.py implementation. + This test ensures the scaling_laws module produces identical configurations + for reproducible isoflop sweeps. 
""" config = IsoFlopSweepConfig(budgets=(1e18,)) result = generate_isoflop_train_args( @@ -614,6 +285,7 @@ def test_generate_isoflop_train_args_snapshot(): ), f"Expected {len(EXPECTED_ISOFLOP_CONFIGS_1E18)} configs, got {len(result)}" for i, (args, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_1E18, strict=True)): + assert isinstance(args, IsoFlopTrainArgs) c = args.candidate actual = { "hidden_size": c.hidden_size, @@ -635,170 +307,11 @@ def test_generate_isoflop_train_args_snapshot(): ), f"Config {i}: {key} mismatch: expected {expected[key]}, got {actual[key]}" -# --- round_flops_to_bucket tests --- - - -def test_round_flops_to_bucket_exact_values(): - """Test that exact significant figures remain unchanged.""" - assert round_flops_to_bucket(1e18) == 1e18 - assert round_flops_to_bucket(1e19) == 1e19 - assert round_flops_to_bucket(3e19) == 3e19 - assert round_flops_to_bucket(6e19) == 6e19 - assert round_flops_to_bucket(1e20) == 1e20 - - -def test_round_flops_to_bucket_rounds_to_one_significant_figure(): - """Test rounding to 1 significant figure.""" - # Small variations should round to nearest integer mantissa - assert round_flops_to_bucket(1.05e19) == 1e19 - assert round_flops_to_bucket(1.4e19) == 1e19 - assert round_flops_to_bucket(1.5e19) == 2e19 - assert round_flops_to_bucket(2.8e19) == 3e19 - assert round_flops_to_bucket(9.5e19) == 1e20 # Wraps to next power - - -def test_round_flops_to_bucket_handles_edge_cases(): - """Test edge cases for FLOP bucket rounding.""" - assert round_flops_to_bucket(0) == 0 - assert round_flops_to_bucket(-1e18) == -1e18 - # Very large values - assert round_flops_to_bucket(5.5e21) == 6e21 - - -# --- fit_scaling_laws tests --- - - -def test_fit_scaling_laws_empty_dataframe(): - """Test that fit_scaling_laws handles empty dataframe.""" - df = pd.DataFrame() - minima_records, scaling_fits, fit_curves = fit_scaling_laws(df) - assert minima_records == [] - assert scaling_fits == {} - assert fit_curves == {} - - -def test_fit_scaling_laws_single_budget(): - """Test fit_scaling_laws with data from a single FLOP budget.""" - # Create synthetic data with multiple token counts at one budget - df = pd.DataFrame( - { - "tokens": [1e9, 2e9, 4e9, 8e9, 16e9], - "loss": [2.5, 2.2, 2.0, 2.1, 2.3], # U-shaped (optimal around 4e9) - "flops": [1e18, 1e18, 1e18, 1e18, 1e18], - "params": [1e8, 1e8, 1e8, 1e8, 1e8], - "hidden_dim": [512, 512, 512, 512, 512], - "num_layers": [6, 6, 6, 6, 6], - "batch_size": [32, 32, 32, 32, 32], - "name": ["run1", "run2", "run3", "run4", "run5"], - "label": ["nemo", "nemo", "nemo", "nemo", "nemo"], - } - ) - minima_records, scaling_fits, fit_curves = fit_scaling_laws(df) - - # Should find exactly one minimum for the single (label, budget) pair - assert len(minima_records) == 1 - rec = minima_records[0] - assert rec.label == "nemo" - assert rec.flops == 1e18 - # Optimal should be near 4e9 (the minimum loss point) - assert 1e9 < rec.optimal_tokens < 20e9 - - # With only one budget, cannot fit scaling law - assert "nemo" not in scaling_fits - - # Should have fit curve for (nemo, 1e18) - assert ("nemo", 1e18) in fit_curves - - -def test_fit_scaling_laws_multiple_budgets(): - """Test fit_scaling_laws with multiple FLOP budgets to fit scaling law.""" - # Create data with two FLOP budgets - df = pd.DataFrame( - { - "tokens": [ - # Budget 1e18 - optimal around 2e9 - 1e9, - 2e9, - 4e9, - # Budget 1e19 - optimal around 6e9 (more tokens for more compute) - 2e9, - 6e9, - 18e9, - ], - "loss": [ - 2.3, - 2.0, - 2.2, # U-shape at 1e18 - 2.0, - 1.7, 
- 1.9, # U-shape at 1e19 - ], - "flops": [1e18, 1e18, 1e18, 1e19, 1e19, 1e19], - "params": [1e8, 1e8, 1e8, 5e8, 5e8, 5e8], - "hidden_dim": [512, 512, 512, 1024, 1024, 1024], - "num_layers": [6, 6, 6, 12, 12, 12], - "batch_size": [32, 32, 32, 64, 64, 64], - "name": [f"run{i}" for i in range(6)], - "label": ["nemo"] * 6, - } - ) - minima_records, scaling_fits, fit_curves = fit_scaling_laws(df) - - # Should find two minima (one per budget) - assert len(minima_records) == 2 - - # Should have scaling fit for nemo - assert "nemo" in scaling_fits - alpha, _A = scaling_fits["nemo"] - - # Scaling law alpha should be positive (more compute -> more tokens) - assert 0 < alpha < 1 - - # Should have fit curves for both budgets - assert ("nemo", 1e18) in fit_curves - assert ("nemo", 1e19) in fit_curves - - -def test_fit_scaling_laws_multiple_labels(): - """Test fit_scaling_laws with multiple dataset labels.""" - df = pd.DataFrame( - { - "tokens": [1e9, 2e9, 4e9, 1e9, 2e9, 4e9], - "loss": [2.5, 2.2, 2.4, 2.3, 2.0, 2.2], - "flops": [1e18, 1e18, 1e18, 1e18, 1e18, 1e18], - "params": [1e8] * 6, - "hidden_dim": [512] * 6, - "num_layers": [6] * 6, - "batch_size": [32] * 6, - "name": [f"run{i}" for i in range(6)], - "label": ["nemo", "nemo", "nemo", "dclm", "dclm", "dclm"], - } - ) - minima_records, _scaling_fits, fit_curves = fit_scaling_laws(df) - - # Should find two minima (one per label at the single budget) - assert len(minima_records) == 2 - labels = {rec.label for rec in minima_records} - assert labels == {"nemo", "dclm"} - - # Should have fit curves for both labels - assert ("nemo", 1e18) in fit_curves - assert ("dclm", 1e18) in fit_curves - - -# --- transform_metrics_for_isoflop tests --- - +# --- Metrics transformation tests --- # Sample tracker_metrics.jsonl data extracted from real runs -# Note: throughput/total_gflops is in GFLOPs, multiply by 1e9 to get FLOPs -# For 1e18 FLOPs, we need ~1e9 GFLOPs (1e9 * 1e9 = 1e18) -# Need at least 3 data points per budget to fit a quadratic -# Loss values form a U-shape in log(tokens) space for each budget SAMPLE_METRICS_DATA = [ # 1e18 budget - 3 runs with U-shaped loss curve - # Too few tokens: model is good but undertrained - # Just right: optimal training - # Too many tokens: model is too small { "run_path": "gs://marin/checkpoints/isoflop-1e+18-d1024-L11-B8-nemo-wider-depth-adapt", "config": { @@ -806,9 +319,9 @@ def test_fit_scaling_laws_multiple_labels(): "trainer": {"train_batch_size": 8}, }, "summary": { - "throughput/total_tokens": 1000000000, # 1B tokens (undertrained) - "throughput/total_gflops": 1000000000.0, # 1e9 GFLOPs = 1e18 FLOPs - "eval/paloma/c4_en/bpb": 1.25, # Higher loss - undertrained + "throughput/total_tokens": 1000000000, + "throughput/total_gflops": 1000000000.0, + "eval/paloma/c4_en/bpb": 1.25, "parameter_count": 400000000, }, }, @@ -819,9 +332,9 @@ def test_fit_scaling_laws_multiple_labels(): "trainer": {"train_batch_size": 16}, }, "summary": { - "throughput/total_tokens": 2500000000, # 2.5B tokens (optimal) - "throughput/total_gflops": 1000000000.0, # 1e9 GFLOPs = 1e18 FLOPs - "eval/paloma/c4_en/bpb": 1.12, # Lowest loss - optimal + "throughput/total_tokens": 2500000000, + "throughput/total_gflops": 1000000000.0, + "eval/paloma/c4_en/bpb": 1.12, "parameter_count": 272513792, }, }, @@ -832,13 +345,13 @@ def test_fit_scaling_laws_multiple_labels(): "trainer": {"train_batch_size": 32}, }, "summary": { - "throughput/total_tokens": 5000000000, # 5B tokens (overtrained/small model) - "throughput/total_gflops": 1000000000.0, # 1e9 GFLOPs = 
1e18 FLOPs - "eval/paloma/c4_en/bpb": 1.18, # Higher loss - model too small + "throughput/total_tokens": 5000000000, + "throughput/total_gflops": 1000000000.0, + "eval/paloma/c4_en/bpb": 1.18, "parameter_count": 156508160, }, }, - # 1e19 budget - 3 runs with U-shaped loss curve (more tokens optimal) + # 1e19 budget - 3 runs { "run_path": "gs://marin/checkpoints/isoflop-1e+19-d2048-L21-B16-nemo-wider-depth-adapt", "config": { @@ -846,9 +359,9 @@ def test_fit_scaling_laws_multiple_labels(): "trainer": {"train_batch_size": 16}, }, "summary": { - "throughput/total_tokens": 3000000000, # 3B tokens (undertrained) - "throughput/total_gflops": 10000000000.0, # 1e10 GFLOPs = 1e19 FLOPs - "eval/paloma/c4_en/bpb": 1.05, # Higher loss - undertrained + "throughput/total_tokens": 3000000000, + "throughput/total_gflops": 10000000000.0, + "eval/paloma/c4_en/bpb": 1.05, "parameter_count": 1800000000, }, }, @@ -859,9 +372,9 @@ def test_fit_scaling_laws_multiple_labels(): "trainer": {"train_batch_size": 32}, }, "summary": { - "throughput/total_tokens": 8000000000, # 8B tokens (optimal) - "throughput/total_gflops": 10000000000.0, # 1e10 GFLOPs = 1e19 FLOPs - "eval/paloma/c4_en/bpb": 0.98, # Lowest loss - optimal + "throughput/total_tokens": 8000000000, + "throughput/total_gflops": 10000000000.0, + "eval/paloma/c4_en/bpb": 0.98, "parameter_count": 998036992, }, }, @@ -872,57 +385,36 @@ def test_fit_scaling_laws_multiple_labels(): "trainer": {"train_batch_size": 64}, }, "summary": { - "throughput/total_tokens": 20000000000, # 20B tokens (overtrained) - "throughput/total_gflops": 10000000000.0, # 1e10 GFLOPs = 1e19 FLOPs - "eval/paloma/c4_en/bpb": 1.02, # Higher loss - model too small + "throughput/total_tokens": 20000000000, + "throughput/total_gflops": 10000000000.0, + "eval/paloma/c4_en/bpb": 1.02, "parameter_count": 400000000, }, }, ] -def test_transform_metrics_for_isoflop_basic(): - """Test basic transformation of metrics data.""" +def test_transform_metrics_for_isoflop(): + """Test transformation of raw metrics data to isoflop analysis format.""" raw_df = pd.DataFrame(SAMPLE_METRICS_DATA) metric_key = "eval/paloma/c4_en/bpb" result = transform_metrics_for_isoflop(raw_df, metric_key) assert len(result) == 6 # 3 runs at 1e18 + 3 runs at 1e19 - assert set(result.columns) == { - "tokens", - "loss", - "flops", - "params", - "hidden_dim", - "num_layers", - "batch_size", - "name", - "label", - } - # Check that values are extracted correctly - first row is d1024/L11 + # Verify specific values from first row (d1024/L11) row0 = result.iloc[0] - assert row0["tokens"] == 1000000000 # 1B tokens + assert row0["tokens"] == 1000000000 assert row0["loss"] == 1.25 assert row0["hidden_dim"] == 1024 assert row0["num_layers"] == 11 assert row0["batch_size"] == 8 + assert row0["flops"] == 1e18 + assert row0["params"] == 400000000 -def test_transform_metrics_for_isoflop_with_label_map(): - """Test transformation with custom label mapping.""" - raw_df = pd.DataFrame(SAMPLE_METRICS_DATA) - metric_key = "eval/paloma/c4_en/bpb" - label_map = {"nemo-wider-depth-adapt": "NeMo"} - - result = transform_metrics_for_isoflop(raw_df, metric_key, label_map) - - assert len(result) == 6 # 3 runs at 1e18 + 3 runs at 1e19 - assert all(result["label"] == "NeMo") - - -def test_transform_metrics_for_isoflop_filters_low_flops(): +def test_transform_metrics_filters_low_flops(): """Test that runs with < 1e18 FLOPs are filtered out.""" raw_df = pd.DataFrame( [ @@ -934,7 +426,7 @@ def test_transform_metrics_for_isoflop_filters_low_flops(): }, 
"summary": { "throughput/total_tokens": 1e7, - "throughput/total_gflops": 1e6, # Only 1e15 FLOPs (< 1e18) + "throughput/total_gflops": 1e6, # Only 1e15 FLOPs "eval/paloma/c4_en/bpb": 3.0, "parameter_count": 1e7, }, @@ -946,81 +438,37 @@ def test_transform_metrics_for_isoflop_filters_low_flops(): assert len(result) == 0 -def test_transform_metrics_for_isoflop_empty_input(): - """Test transformation with empty input.""" - raw_df = pd.DataFrame() - result = transform_metrics_for_isoflop(raw_df, "eval/paloma/c4_en/bpb") - assert result.empty - - -def test_transform_metrics_for_isoflop_missing_fields(): - """Test transformation handles missing fields gracefully.""" - raw_df = pd.DataFrame( - [ - { - "run_path": "gs://marin/checkpoints/isoflop-1e+18-d512-L6-B32-incomplete", - "config": {"model": {}, "trainer": {}}, - "summary": { - # Missing throughput/total_tokens - "throughput/total_gflops": 1000001.0, - "eval/paloma/c4_en/bpb": 1.5, - }, - } - ] - ) - - result = transform_metrics_for_isoflop(raw_df, "eval/paloma/c4_en/bpb") - # Should skip the row with missing required fields - assert len(result) == 0 - - -# --- Integration test: fit_scaling_laws with transform_metrics_for_isoflop --- +# --- End-to-end integration test --- def test_end_to_end_analysis_pipeline(): - """Integration test: transform metrics and fit scaling laws.""" + """Integration test: transform metrics and fit scaling laws end-to-end. + + Uses SAMPLE_METRICS_DATA (simulating real wandb metrics) to verify the full + pipeline: metrics transformation -> curve fitting -> scaling law extraction. + """ raw_df = pd.DataFrame(SAMPLE_METRICS_DATA) # Transform metrics isoflop_df = transform_metrics_for_isoflop(raw_df, "eval/paloma/c4_en/bpb") - assert len(isoflop_df) == 6 # 3 runs at 1e18 + 3 runs at 1e19 + assert len(isoflop_df) == 6 # Fit scaling laws - minima_records, scaling_fits, _fit_curves = fit_scaling_laws(isoflop_df) + minima_records, _scaling_fits, _ = fit_scaling_laws(isoflop_df) - # With 2 budgets (1e18, 1e19), each with 3 points, we should get 2 minima + # Should find two minima (one per budget: 1e18 and 1e19) assert len(minima_records) == 2 + flops_budgets = {rec.flops for rec in minima_records} + assert flops_budgets == {1e18, 1e19} - # Should have a scaling fit for the label - assert len(scaling_fits) == 1 - label = next(iter(scaling_fits.keys())) - alpha, A = scaling_fits[label] - - # Sanity check the scaling law parameters - assert 0 < alpha < 1 # Typical range for token scaling exponent - assert A > 0 + # Verify fitted minima are near expected optimal points + # Curve fitting interpolates to find analytical minimum of fitted quadratic + minima_by_flops = {rec.flops: rec for rec in minima_records} + # At 1e18: raw data optimal at 2.5B (loss=1.12), fitted minimum ~2.6B + assert abs(minima_by_flops[1e18].optimal_tokens - 2.6e9) < 0.2e9 + assert abs(minima_by_flops[1e18].loss_at_optimal - 1.12) < 0.01 -def test_minima_records_have_scaling_fit_params(): - """Test that minima records are augmented with scaling fit parameters.""" - df = pd.DataFrame( - { - "tokens": [1e9, 2e9, 4e9, 2e9, 6e9, 18e9], - "loss": [2.3, 2.0, 2.2, 2.0, 1.7, 1.9], - "flops": [1e18, 1e18, 1e18, 1e19, 1e19, 1e19], - "params": [1e8, 1e8, 1e8, 5e8, 5e8, 5e8], - "hidden_dim": [512, 512, 512, 1024, 1024, 1024], - "num_layers": [6, 6, 6, 12, 12, 12], - "batch_size": [32, 32, 32, 64, 64, 64], - "name": [f"run{i}" for i in range(6)], - "label": ["nemo"] * 6, - } - ) - minima_records, scaling_fits, _ = fit_scaling_laws(df) - - # All records for a label with 
a scaling fit should have the params - for rec in minima_records: - if rec.label in scaling_fits: - alpha, A = scaling_fits[rec.label] - assert rec.scaling_alpha == alpha - assert rec.scaling_A == A + # At 1e19: raw data optimal at 8B (loss=0.98), fitted minimum ~8.8B + assert abs(minima_by_flops[1e19].optimal_tokens - 8.8e9) < 0.2e9 + assert abs(minima_by_flops[1e19].loss_at_optimal - 0.98) < 0.01 From 23c87a71e73c8cbcf816a066f2005cf3b03cb173 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 5 Jan 2026 16:16:51 -0800 Subject: [PATCH 35/79] Remove Parsing Since I moved to getting metadata directly --- .../marin/scaling_laws/isoflop_analysis.py | 20 ++++++------------- tests/test_scaling_laws.py | 13 +++--------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index c26411c03d..937f536d32 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -640,32 +640,25 @@ def generate_isoflop_train_args( # ---------------- Helpers ---------------- -def parse_isoflop_run_name(run_name: str) -> dict | None: - """Parse metadata from isoflop run name. +def parse_isoflop_run_name(run_name: str) -> str | None: + """Parse experiment name from isoflop run name. Expected format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} Optionally with a trailing - which is ignored. E.g., 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt' or 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt-a1b2c3' - Returns dict with: flops, d, L, B, experiment_name or None if parsing fails. + Returns experiment_name or None if parsing fails. """ # Strip optional - suffix run_name = re.sub(r"-[0-9a-fA-F]{6}$", "", run_name) - pattern = r"isoflop-([0-9.e+]+)-d(\d+)-L(\d+)-B(\d+)-(.+)" + pattern = r"isoflop-(?:[0-9.e+]+)-d(?:\d+)-L(?:\d+)-B(?:\d+)-(.+)" match = re.match(pattern, run_name) if not match: return None - flops_str, d, L, B, exp_name = match.groups() - return { - "flops": float(flops_str), - "d": int(d), - "L": int(L), - "B": int(B), - "experiment_name": exp_name, - } + return match.group(1) def robust_quad_logx(x: jnp.ndarray, y: jnp.ndarray, delta: float = 1.0) -> tuple[float, float, float]: @@ -861,8 +854,7 @@ def transform_metrics_for_isoflop( batch_size = trainer_config.get("train_batch_size") # Determine experiment name and label from run name - meta = parse_isoflop_run_name(run_name) - exp_name = meta["experiment_name"] if meta else run_name + exp_name = parse_isoflop_run_name(run_name) or run_name if label_map and exp_name in label_map: label = label_map[exp_name] else: diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index cdfd4670d9..2c130039e1 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -133,21 +133,14 @@ def test_compute_total_flops_linear_in_batch_and_steps(): def test_parse_isoflop_run_name(): - """Test parsing isoflop run names extracts correct values.""" + """Test parsing isoflop run names extracts experiment names.""" # Standard name result = parse_isoflop_run_name("isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt") - assert result is not None - assert result["flops"] == 1e19 - assert result["d"] == 2048 - assert result["L"] == 16 - assert result["B"] == 1024 - assert result["experiment_name"] == "nemo-wider-depth-adapt" + assert result == "nemo-wider-depth-adapt" # With hash suffix result = 
parse_isoflop_run_name("isoflop-1e+18-d512-L8-B128-dclm-a1b2c3") - assert result is not None - assert result["flops"] == 1e18 - assert result["experiment_name"] == "dclm" + assert result == "dclm" # Invalid formats return None assert parse_isoflop_run_name("not-a-valid-name") is None From 0ffaa6eb7a99f85d5e667bfa5bfe3feb9c3dfc60 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 5 Jan 2026 16:18:50 -0800 Subject: [PATCH 36/79] Keep Only More General Test --- tests/test_scaling_laws.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 2c130039e1..efa43046a6 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -134,11 +134,6 @@ def test_compute_total_flops_linear_in_batch_and_steps(): def test_parse_isoflop_run_name(): """Test parsing isoflop run names extracts experiment names.""" - # Standard name - result = parse_isoflop_run_name("isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt") - assert result == "nemo-wider-depth-adapt" - - # With hash suffix result = parse_isoflop_run_name("isoflop-1e+18-d512-L8-B128-dclm-a1b2c3") assert result == "dclm" From 23ccd6b59ac6f5ec2ebdd823664ffa88ebd4458c Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 5 Jan 2026 16:40:06 -0800 Subject: [PATCH 37/79] Oversafe Claude --- lib/marin/src/marin/scaling_laws/scaling_plots.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/scaling_plots.py b/lib/marin/src/marin/scaling_laws/scaling_plots.py index 3f0418e52e..a971b46a08 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_plots.py +++ b/lib/marin/src/marin/scaling_laws/scaling_plots.py @@ -30,12 +30,7 @@ from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT -try: - import wandb - - WANDB_AVAILABLE = True -except ImportError: - WANDB_AVAILABLE = False +import wandb logger = logging.getLogger(__name__) @@ -312,10 +307,6 @@ def upload_plots_to_wandb( project: WandB project run_name: Name for the WandB run """ - if not WANDB_AVAILABLE: - logger.warning("wandb not available, cannot upload plots") - return - wandb.login() run = wandb.init( entity=entity, From c922953069447936a6b9f2ec17aa1351dee63b6c Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 5 Jan 2026 18:03:27 -0800 Subject: [PATCH 38/79] Split Analysis, Plotting, and WandB steps --- lib/marin/src/marin/scaling_laws/__init__.py | 8 + .../marin/scaling_laws/isoflop_analysis.py | 274 +++++++++++++++--- .../src/marin/scaling_laws/scaling_plots.py | 1 - 3 files changed, 234 insertions(+), 49 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 83ae5adb21..79e7f48cab 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -17,19 +17,23 @@ CandidateConfig, IsoFlopAnalysisConfig, IsoFlopAnalysisResult, + IsoFlopPlotsConfig, IsoFlopSweepConfig, IsoFlopTrainArgs, MinimaRecord, + UploadPlotsToWandbConfig, build_model_config, build_optimizer_config, candidate_configs, compute_transformer_params, generate_isoflop_train_args, isoflop_analysis_step, + isoflop_plots_step, pick_v5p_type, predict_optimal_config, predict_optimal_configs_for_budgets, run_isoflop_analysis, + upload_isoflop_plots_to_wandb_step, ) from marin.scaling_laws.scaling_ladder import ( ScalingLadderRungConfig, @@ -49,11 +53,13 @@ "CandidateConfig", "IsoFlopAnalysisConfig", "IsoFlopAnalysisResult", + "IsoFlopPlotsConfig", "IsoFlopSweepConfig", "IsoFlopTrainArgs", 
"MinimaRecord", "ScalingLadderRungConfig", "ScalingLadderSuite", + "UploadPlotsToWandbConfig", "build_model_config", "build_optimizer_config", "candidate_configs", @@ -62,6 +68,7 @@ "create_scaling_plot", "generate_isoflop_train_args", "isoflop_analysis_step", + "isoflop_plots_step", "pick_v5p_type", "predict_optimal_config", "predict_optimal_configs_for_budgets", @@ -69,5 +76,6 @@ "save_plots", "scaling_ladder_rung_step", "scaling_ladder_suite", + "upload_isoflop_plots_to_wandb_step", "upload_plots_to_wandb", ] diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 937f536d32..7887637014 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -14,19 +14,36 @@ """IsoFLOP analysis for finding compute-optimal training configurations. -Primary usage - create an ExecutorStep for your pipeline: +Primary usage - create ExecutorSteps for your pipeline: - from marin.scaling_laws import isoflop_analysis_step + from marin.scaling_laws import ( + isoflop_analysis_step, + isoflop_plots_step, + upload_isoflop_plots_to_wandb_step, + ) + # Step 1: Compute metrics and fit scaling laws analysis = isoflop_analysis_step( name="my-scaling-analysis", training_runs=my_training_steps, # list of ExecutorStep ) -The step will: + # Step 2: Generate HTML plots (optional) + plots = isoflop_plots_step( + name="my-scaling-plots", + analysis_step=analysis, + ) + + # Step 3: Upload to WandB (optional) + upload = upload_isoflop_plots_to_wandb_step( + name="upload-scaling-plots", + analysis_step=analysis, + ) + +The analysis step will: 1. Read eval metrics from completed training runs 2. Fit scaling laws to find compute-optimal token counts -3. Save plots and results to the output path +3. Save results to JSON/parquet files For programmatic use, see `run_isoflop_analysis()` which returns a `IsoFlopAnalysisResult`. """ @@ -999,6 +1016,13 @@ def to_json_dict(self) -> dict: } +def _parse_fit_curve_coeffs(coeffs: Sequence[float]) -> tuple[float, float, float, float, float]: + if len(coeffs) != 5: + raise ValueError(f"Expected 5 fit curve coefficients, got {len(coeffs)}") + a, b, c, token_min, token_max = coeffs + return (float(a), float(b), float(c), float(token_min), float(token_max)) + + # ---------------- ExecutorStep Config ---------------- @@ -1012,14 +1036,27 @@ class IsoFlopAnalysisConfig(EvalMetricsAnalysisConfig): label_map: tuple[tuple[str, str], ...] 
| None = None """Optional mapping from experiment_name -> display label as tuple of pairs.""" - save_plots: bool = True - """Whether to save HTML plots to output_path.""" - upload_to_wandb: bool = True - """Whether to upload plots to WandB.""" +@dataclass(frozen=True) +class IsoFlopPlotsConfig: + """Configuration for isoflop plots ExecutorStep.""" + + analysis_output_path: str + """Path to the isoflop analysis output (containing isoflop_analysis_result.json).""" + + output_path: str + """Path to save the HTML plots.""" + + +@dataclass(frozen=True) +class UploadPlotsToWandbConfig: + """Configuration for uploading plots to WandB.""" + + plots_path: str + """Path to the directory containing HTML plots.""" wandb_entity: str = WANDB_ENTITY - """WandB entity for uploads (defaults to WANDB_ENTITY env var or 'marin-community').""" + """WandB entity for uploads.""" wandb_project: str = f"{WANDB_PROJECT}-analysis" """WandB project for uploads.""" @@ -1071,27 +1108,102 @@ def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: json.dump(result.to_json_dict(), f, indent=2) logger.info(f"Saved results to {result_path}") - if config.save_plots: - from marin.scaling_laws.scaling_plots import ( - create_isoflop_plot, - create_scaling_plot, - save_plots, - ) + # Also save the full dataframe and fit curves for downstream plotting + df_path = os.path.join(config.output_path, "isoflop_df.parquet") + isoflop_df.to_parquet(df_path) + logger.info(f"Saved dataframe to {df_path}") + + fit_curves_path = os.path.join(config.output_path, "fit_curves.json") + # Convert tuple keys to strings for JSON serialization + fit_curves_json = {f"{label}|{flops}": list(coeffs) for (label, flops), coeffs in fit_curves.items()} + with fs.open(fit_curves_path, "w") as f: + json.dump(fit_curves_json, f, indent=2) + logger.info(f"Saved fit curves to {fit_curves_path}") + + +def _run_isoflop_plots_step(config: IsoFlopPlotsConfig) -> None: + """Generate and save isoflop plots (called by ExecutorStep).""" + from marin.scaling_laws.scaling_plots import ( + create_isoflop_plot, + create_scaling_plot, + save_plots, + ) - fig_isoflop = create_isoflop_plot(isoflop_df, minima_records, fit_curves) - fig_scaling = create_scaling_plot(minima_records, scaling_fits) - save_plots(fig_isoflop, fig_scaling, config.output_path) + fs, _, _ = fsspec.get_fs_token_paths(config.analysis_output_path) - if config.upload_to_wandb: - from marin.scaling_laws.scaling_plots import upload_plots_to_wandb + # Load the analysis results + result_path = os.path.join(config.analysis_output_path, "isoflop_analysis_result.json") + with fs.open(result_path, "r") as f: + result_dict = json.load(f) - upload_plots_to_wandb( - fig_isoflop, - fig_scaling, - entity=config.wandb_entity, - project=config.wandb_project, - run_name=config.wandb_run_name, - ) + # Load the dataframe + df_path = os.path.join(config.analysis_output_path, "isoflop_df.parquet") + isoflop_df = pd.read_parquet(df_path) + + # Load fit curves and reconstruct tuple keys + fit_curves_path = os.path.join(config.analysis_output_path, "fit_curves.json") + with fs.open(fit_curves_path, "r") as f: + fit_curves_json = json.load(f) + fit_curves: dict[tuple[str, float], tuple[float, float, float, float, float]] = {} + for key_str, coeffs in fit_curves_json.items(): + label, flops = key_str.rsplit("|", 1) + fit_curves[(label, float(flops))] = _parse_fit_curve_coeffs(coeffs) + + # Reconstruct minima records + minima_records = [MinimaRecord(**r) for r in result_dict["minima_records"]] + scaling_fits = {k: 
tuple(v) for k, v in result_dict["scaling_fits"].items()} + + # Create plots + fig_isoflop = create_isoflop_plot(isoflop_df, minima_records, fit_curves) + fig_scaling = create_scaling_plot(minima_records, scaling_fits) + + # Save plots + save_plots(fig_isoflop, fig_scaling, config.output_path) + + +def _run_upload_plots_to_wandb_step(config: UploadPlotsToWandbConfig) -> None: + """Upload plots to WandB (called by ExecutorStep).""" + from marin.scaling_laws.scaling_plots import ( + create_isoflop_plot, + create_scaling_plot, + upload_plots_to_wandb, + ) + + fs, _, _ = fsspec.get_fs_token_paths(config.plots_path) + + # Load the analysis results to regenerate plots + result_path = os.path.join(config.plots_path, "isoflop_analysis_result.json") + with fs.open(result_path, "r") as f: + result_dict = json.load(f) + + # Load the dataframe + df_path = os.path.join(config.plots_path, "isoflop_df.parquet") + isoflop_df = pd.read_parquet(df_path) + + # Load fit curves and reconstruct tuple keys + fit_curves_path = os.path.join(config.plots_path, "fit_curves.json") + with fs.open(fit_curves_path, "r") as f: + fit_curves_json = json.load(f) + fit_curves: dict[tuple[str, float], tuple[float, float, float, float, float]] = {} + for key_str, coeffs in fit_curves_json.items(): + label, flops = key_str.rsplit("|", 1) + fit_curves[(label, float(flops))] = _parse_fit_curve_coeffs(coeffs) + + # Reconstruct minima records + minima_records = [MinimaRecord(**r) for r in result_dict["minima_records"]] + scaling_fits = {k: tuple(v) for k, v in result_dict["scaling_fits"].items()} + + # Create plots + fig_isoflop = create_isoflop_plot(isoflop_df, minima_records, fit_curves) + fig_scaling = create_scaling_plot(minima_records, scaling_fits) + + upload_plots_to_wandb( + fig_isoflop, + fig_scaling, + entity=config.wandb_entity, + project=config.wandb_project, + run_name=config.wandb_run_name, + ) # ---------------- Primary Export: ExecutorStep Factory ---------------- @@ -1102,41 +1214,32 @@ def isoflop_analysis_step( training_runs: Sequence[ExecutorStep | InputName], metric_key: str = DEFAULT_METRIC_KEY, label_map: dict[str, str] | None = None, - save_plots: bool = True, - upload_to_wandb: bool = True, - wandb_entity: str = WANDB_ENTITY, - wandb_project: str = f"{WANDB_PROJECT}-analysis", - wandb_run_name: str | None = None, ) -> ExecutorStep: """Create an ExecutorStep for scaling ladder analysis. - This is the primary interface for using scaling ladder analysis in a pipeline. - The step will: - 1. Wait for all training runs to complete - 2. Read eval metrics from the training runs - 3. Fit scaling laws to find compute-optimal configurations - 4. Save plots and results to the output path + This step computes scaling law fits and saves results to JSON/parquet files. + For plotting, use `isoflop_plots_step()`. For WandB upload, use + `upload_isoflop_plots_to_wandb_step()`. 
Args: name: Name for this executor step training_runs: Training run ExecutorSteps or InputNames to analyze metric_key: Which metric to use for loss (default: eval/paloma/c4_en/bpb) label_map: Optional mapping from experiment_name -> display label - save_plots: Whether to save HTML plots (default: True) - upload_to_wandb: Whether to upload plots to WandB (default: True) - wandb_entity: WandB entity for uploads - wandb_project: WandB project for uploads - wandb_run_name: Name for WandB run (defaults to step name) Returns: ExecutorStep configured to run the analysis Example: - >>> from marin.scaling_laws import isoflop_analysis_step + >>> from marin.scaling_laws import isoflop_analysis_step, isoflop_plots_step >>> analysis = isoflop_analysis_step( ... name="my-scaling-analysis", ... training_runs=my_training_steps, ... ) + >>> plots = isoflop_plots_step( + ... name="my-scaling-plots", + ... analysis_step=analysis, + ... ) """ run_paths = [output_path_of(run) if isinstance(run, ExecutorStep) else run for run in training_runs] @@ -1145,8 +1248,83 @@ def isoflop_analysis_step( output_path=this_output_path(), metric_key=metric_key, label_map=tuple(label_map.items()) if label_map else None, - save_plots=save_plots, - upload_to_wandb=upload_to_wandb, + ) + + return ExecutorStep( + name=name, + fn=_run_isoflop_analysis_step, + config=config, + description=f"Scaling ladder analysis for {len(training_runs)} training runs", + ) + + +def isoflop_plots_step( + name: str, + analysis_step: ExecutorStep | InputName, +) -> ExecutorStep: + """Create an ExecutorStep to generate isoflop HTML plots. + + This step reads the output from an isoflop_analysis_step and generates + HTML plots for the isoflop curves and scaling fits. + + Args: + name: Name for this executor step + analysis_step: The isoflop_analysis_step to read results from + + Returns: + ExecutorStep configured to generate plots + + Example: + >>> analysis = isoflop_analysis_step(name="analysis", training_runs=runs) + >>> plots = isoflop_plots_step(name="plots", analysis_step=analysis) + """ + analysis_path = output_path_of(analysis_step) if isinstance(analysis_step, ExecutorStep) else analysis_step + + config = IsoFlopPlotsConfig( + analysis_output_path=analysis_path, + output_path=this_output_path(), + ) + + return ExecutorStep( + name=name, + fn=_run_isoflop_plots_step, + config=config, + description="Generate isoflop HTML plots", + ) + + +def upload_isoflop_plots_to_wandb_step( + name: str, + analysis_step: ExecutorStep | InputName, + wandb_entity: str = WANDB_ENTITY, + wandb_project: str = f"{WANDB_PROJECT}-analysis", + wandb_run_name: str | None = None, +) -> ExecutorStep: + """Create an ExecutorStep to upload isoflop plots to WandB. + + This step reads the analysis results and uploads interactive plots to WandB. + + Args: + name: Name for this executor step + analysis_step: The isoflop_analysis_step to read results from + wandb_entity: WandB entity for uploads + wandb_project: WandB project for uploads + wandb_run_name: Name for WandB run (defaults to step name) + + Returns: + ExecutorStep configured to upload plots to WandB + + Example: + >>> analysis = isoflop_analysis_step(name="analysis", training_runs=runs) + >>> upload = upload_isoflop_plots_to_wandb_step( + ... name="upload-plots", + ... analysis_step=analysis, + ... 
) + """ + analysis_path = output_path_of(analysis_step) if isinstance(analysis_step, ExecutorStep) else analysis_step + + config = UploadPlotsToWandbConfig( + plots_path=analysis_path, wandb_entity=wandb_entity, wandb_project=wandb_project, wandb_run_name=wandb_run_name or name, @@ -1154,9 +1332,9 @@ def isoflop_analysis_step( return ExecutorStep( name=name, - fn=_run_isoflop_analysis_step, + fn=_run_upload_plots_to_wandb_step, config=config, - description=f"Scaling ladder analysis for {len(training_runs)} training runs", + description="Upload isoflop plots to WandB", ) diff --git a/lib/marin/src/marin/scaling_laws/scaling_plots.py b/lib/marin/src/marin/scaling_laws/scaling_plots.py index a971b46a08..ceba24ee65 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_plots.py +++ b/lib/marin/src/marin/scaling_laws/scaling_plots.py @@ -32,7 +32,6 @@ import wandb - logger = logging.getLogger(__name__) # ---------------- Theme ---------------- From 6074f9ebaf92cc090aad09423cbff802b04ea3b6 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 6 Jan 2026 10:38:47 -0800 Subject: [PATCH 39/79] Use Typed Return Values --- lib/marin/src/marin/scaling_laws/__init__.py | 8 ++ .../marin/scaling_laws/isoflop_analysis.py | 128 ++++++++++++------ .../src/marin/scaling_laws/scaling_ladder.py | 5 +- .../src/marin/scaling_laws/scaling_plots.py | 13 +- tests/test_scaling_laws.py | 8 +- 5 files changed, 108 insertions(+), 54 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 79e7f48cab..caf51f8383 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -15,17 +15,21 @@ from marin.scaling_laws.isoflop_analysis import ( DEFAULT_BUDGETS, CandidateConfig, + FitScalingLawsResult, IsoFlopAnalysisConfig, IsoFlopAnalysisResult, IsoFlopPlotsConfig, IsoFlopSweepConfig, IsoFlopTrainArgs, MinimaRecord, + QuadraticFitCoeffs, + ScalingFit, UploadPlotsToWandbConfig, build_model_config, build_optimizer_config, candidate_configs, compute_transformer_params, + fit_scaling_laws, generate_isoflop_train_args, isoflop_analysis_step, isoflop_plots_step, @@ -51,12 +55,15 @@ __all__ = [ "DEFAULT_BUDGETS", "CandidateConfig", + "FitScalingLawsResult", "IsoFlopAnalysisConfig", "IsoFlopAnalysisResult", "IsoFlopPlotsConfig", "IsoFlopSweepConfig", "IsoFlopTrainArgs", "MinimaRecord", + "QuadraticFitCoeffs", + "ScalingFit", "ScalingLadderRungConfig", "ScalingLadderSuite", "UploadPlotsToWandbConfig", @@ -66,6 +73,7 @@ "compute_transformer_params", "create_isoflop_plot", "create_scaling_plot", + "fit_scaling_laws", "generate_isoflop_train_args", "isoflop_analysis_step", "isoflop_plots_step", diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 7887637014..8c7b465565 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -55,6 +55,7 @@ import re from collections.abc import Iterator, Sequence from dataclasses import asdict, dataclass, replace +from typing import NamedTuple import fsspec import jax.numpy as jnp @@ -105,6 +106,38 @@ V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512] +# ---------------- Typed Tuples ---------------- + + +class ScalingFit(NamedTuple): + """Scaling law fit parameters for N* ~ A * C^alpha.""" + + alpha: float + """Exponent in scaling law.""" + + A: float + """Coefficient in scaling law.""" + + +class QuadraticFitCoeffs(NamedTuple): + """Quadratic fit 
coefficients for loss = a * log10(tokens)^2 + b * log10(tokens) + c.""" + + a: float + """Quadratic coefficient.""" + + b: float + """Linear coefficient.""" + + c: float + """Constant term.""" + + token_min: float + """Minimum token count used for fitting.""" + + token_max: float + """Maximum token count used for fitting.""" + + # ---------------- IsoFLOP Sweep Config ---------------- @dataclass(frozen=True) class IsoFlopSweepConfig: @@ -220,6 +253,20 @@ class MinimaRecord: scaling_A: float | None = None +@dataclass +class FitScalingLawsResult: + """Result from fit_scaling_laws containing minima, scaling fits, and fit curves.""" + + minima_records: list[MinimaRecord] + """List of optimal configurations found at each (label, flops) point.""" + + scaling_fits: dict[str, ScalingFit] + """Per-label scaling fits: {label: ScalingFit} for N* ~ A * C^alpha.""" + + fit_curves: dict[tuple[str, float], QuadraticFitCoeffs] + """Quadratic fit coefficients {(label, flops): QuadraticFitCoeffs} for plotting.""" + + # ---------------- Candidate Config Generation ---------------- @@ -708,11 +755,7 @@ def objective(params): def fit_scaling_laws( df: pd.DataFrame, -) -> tuple[ - list[MinimaRecord], - dict[str, tuple[float, float]], - dict[tuple[str, float], tuple[float, float, float, float, float]], -]: +) -> FitScalingLawsResult: """ Fit scaling laws and extract optimal configurations. @@ -720,20 +763,17 @@ def fit_scaling_laws( df: DataFrame with columns: tokens, loss, flops, params, name, label Returns: - - minima_records: List of dicts with optimal config info per (label, flops) - - scaling_fits: Dict of {label: (alpha, A)} for N* ~ A * C^alpha - - fit_curves: Dict of {(label, flops): (a, b, c, token_min, token_max)} quadratic coefficients - for plotting + FitScalingLawsResult containing minima_records, scaling_fits, and fit_curves. 
""" if df is None or df.empty: - return [], {}, {} + return FitScalingLawsResult(minima_records=[], scaling_fits={}, fit_curves={}) datasets = list(dict.fromkeys(df["label"].tolist())) buckets = sorted(df.flops.unique()) minima_records: list[MinimaRecord] = [] - fit_curves: dict[tuple[str, float], tuple[float, float, float, float, float]] = {} + fit_curves: dict[tuple[str, float], QuadraticFitCoeffs] = {} # Fit quadratic for each (label, budget) and find minima for lab in datasets: @@ -750,7 +790,7 @@ def fit_scaling_laws( jnp.array(sub.loss.values, dtype=jnp.float64), ) # Store coefficients along with token range used for fitting - fit_curves[(lab, C)] = (a, b, c, float(tokens_array.min()), float(tokens_array.max())) + fit_curves[(lab, C)] = QuadraticFitCoeffs(a, b, c, float(tokens_array.min()), float(tokens_array.max())) if a == 0: continue @@ -776,7 +816,7 @@ def fit_scaling_laws( ) # Fit scaling law N* ~ A * C^alpha per dataset - scaling_fits: dict[str, tuple[float, float]] = {} + scaling_fits: dict[str, ScalingFit] = {} by_lab: dict[str, list[MinimaRecord]] = {} for rec in minima_records: by_lab.setdefault(rec.label, []).append(rec) @@ -793,14 +833,18 @@ def fit_scaling_laws( alpha, logA = jnp.polyfit(jnp.log10(Cs), jnp.log10(Ns), 1) A = float(10**logA) alpha = float(alpha) - scaling_fits[lab] = (alpha, A) + scaling_fits[lab] = ScalingFit(alpha, A) # Augment minima records with scaling fit params for rec in recs: rec.scaling_alpha = alpha rec.scaling_A = A - return minima_records, scaling_fits, fit_curves + return FitScalingLawsResult( + minima_records=minima_records, + scaling_fits=scaling_fits, + fit_curves=fit_curves, + ) def transform_metrics_for_isoflop( @@ -898,7 +942,7 @@ def transform_metrics_for_isoflop( def predict_optimal_config( - scaling_fits: dict[str, tuple[float, float]], + scaling_fits: dict[str, ScalingFit], target_flops: float, label: str, sweep_config: IsoFlopSweepConfig | None = None, @@ -912,7 +956,7 @@ def predict_optimal_config( 3. Selects the candidate whose token count is closest to the predicted optimal Args: - scaling_fits: Dict of {label: (alpha, A)} from scaling ladder result. + scaling_fits: Dict of {label: ScalingFit} from scaling ladder result. target_flops: Target compute budget in FLOPs. label: Dataset/experiment label to use for scaling fit. sweep_config: Optional IsoFlopSweepConfig. If None, uses defaults. @@ -951,7 +995,7 @@ def predict_optimal_config( def predict_optimal_configs_for_budgets( - scaling_fits: dict[str, tuple[float, float]], + scaling_fits: dict[str, ScalingFit], target_budgets: list[float], label: str, sweep_config: IsoFlopSweepConfig | None = None, @@ -960,7 +1004,7 @@ def predict_optimal_configs_for_budgets( """Predict optimal configs for multiple target compute budgets. Args: - scaling_fits: Dict of {label: (alpha, A)} from scaling ladder result. + scaling_fits: Dict of {label: ScalingFit} from scaling ladder result. target_budgets: List of target compute budgets in FLOPs. label: Dataset/experiment label to use for scaling fit. sweep_config: Optional IsoFlopSweepConfig. If None, uses defaults. 
@@ -995,8 +1039,8 @@ class IsoFlopAnalysisResult: configs: list[CandidateConfig] """List of optimal CandidateConfig for each (label, flops_budget) pair.""" - scaling_fits: dict[str, tuple[float, float]] - """Per-label scaling fits: {label: (alpha, A)} for N* ~ A * C^alpha.""" + scaling_fits: dict[str, ScalingFit] + """Per-label scaling fits: {label: ScalingFit} for N* ~ A * C^alpha.""" isoflop_df: pd.DataFrame """Transformed dataframe used for analysis.""" @@ -1004,8 +1048,8 @@ class IsoFlopAnalysisResult: minima_records: list[MinimaRecord] """Raw minima records with detailed info for each optimum.""" - fit_curves: dict[tuple[str, float], tuple[float, float, float, float, float]] - """Quadratic fit coefficients {(label, flops): (a, b, c, token_min, token_max)} for plotting.""" + fit_curves: dict[tuple[str, float], QuadraticFitCoeffs] + """Quadratic fit coefficients {(label, flops): QuadraticFitCoeffs} for plotting.""" def to_json_dict(self) -> dict: """Convert result to JSON-serializable dict (excludes DataFrame and fit_curves).""" @@ -1016,11 +1060,11 @@ def to_json_dict(self) -> dict: } -def _parse_fit_curve_coeffs(coeffs: Sequence[float]) -> tuple[float, float, float, float, float]: +def _parse_fit_curve_coeffs(coeffs: Sequence[float]) -> QuadraticFitCoeffs: if len(coeffs) != 5: raise ValueError(f"Expected 5 fit curve coefficients, got {len(coeffs)}") a, b, c, token_min, token_max = coeffs - return (float(a), float(b), float(c), float(token_min), float(token_max)) + return QuadraticFitCoeffs(float(a), float(b), float(c), float(token_min), float(token_max)) # ---------------- ExecutorStep Config ---------------- @@ -1084,20 +1128,20 @@ def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: logger.info(f"Labels found: {isoflop_df['label'].unique().tolist()}") logger.info(f"FLOP budgets: {sorted(isoflop_df['flops'].unique())}") - minima_records, scaling_fits, fit_curves = fit_scaling_laws(isoflop_df) + fit_result = fit_scaling_laws(isoflop_df) - logger.info(f"Found {len(minima_records)} optimal configurations") - for label, (alpha, A) in scaling_fits.items(): + logger.info(f"Found {len(fit_result.minima_records)} optimal configurations") + for label, (alpha, A) in fit_result.scaling_fits.items(): logger.info(f" {label}: N* = {A:.2e} * C^{alpha:.3f}") - configs = _minima_to_candidates(minima_records) + configs = _minima_to_candidates(fit_result.minima_records) result = IsoFlopAnalysisResult( configs=configs, - scaling_fits=scaling_fits, + scaling_fits=fit_result.scaling_fits, isoflop_df=isoflop_df, - minima_records=minima_records, - fit_curves=fit_curves, + minima_records=fit_result.minima_records, + fit_curves=fit_result.fit_curves, ) fs, _, _ = fsspec.get_fs_token_paths(config.output_path) @@ -1115,7 +1159,7 @@ def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: fit_curves_path = os.path.join(config.output_path, "fit_curves.json") # Convert tuple keys to strings for JSON serialization - fit_curves_json = {f"{label}|{flops}": list(coeffs) for (label, flops), coeffs in fit_curves.items()} + fit_curves_json = {f"{label}|{flops}": list(coeffs) for (label, flops), coeffs in result.fit_curves.items()} with fs.open(fit_curves_path, "w") as f: json.dump(fit_curves_json, f, indent=2) logger.info(f"Saved fit curves to {fit_curves_path}") @@ -1144,14 +1188,14 @@ def _run_isoflop_plots_step(config: IsoFlopPlotsConfig) -> None: fit_curves_path = os.path.join(config.analysis_output_path, "fit_curves.json") with fs.open(fit_curves_path, "r") as f: fit_curves_json = 
json.load(f) - fit_curves: dict[tuple[str, float], tuple[float, float, float, float, float]] = {} + fit_curves: dict[tuple[str, float], QuadraticFitCoeffs] = {} for key_str, coeffs in fit_curves_json.items(): label, flops = key_str.rsplit("|", 1) fit_curves[(label, float(flops))] = _parse_fit_curve_coeffs(coeffs) # Reconstruct minima records minima_records = [MinimaRecord(**r) for r in result_dict["minima_records"]] - scaling_fits = {k: tuple(v) for k, v in result_dict["scaling_fits"].items()} + scaling_fits = {k: ScalingFit(*v) for k, v in result_dict["scaling_fits"].items()} # Create plots fig_isoflop = create_isoflop_plot(isoflop_df, minima_records, fit_curves) @@ -1184,14 +1228,14 @@ def _run_upload_plots_to_wandb_step(config: UploadPlotsToWandbConfig) -> None: fit_curves_path = os.path.join(config.plots_path, "fit_curves.json") with fs.open(fit_curves_path, "r") as f: fit_curves_json = json.load(f) - fit_curves: dict[tuple[str, float], tuple[float, float, float, float, float]] = {} + fit_curves: dict[tuple[str, float], QuadraticFitCoeffs] = {} for key_str, coeffs in fit_curves_json.items(): label, flops = key_str.rsplit("|", 1) fit_curves[(label, float(flops))] = _parse_fit_curve_coeffs(coeffs) # Reconstruct minima records minima_records = [MinimaRecord(**r) for r in result_dict["minima_records"]] - scaling_fits = {k: tuple(v) for k, v in result_dict["scaling_fits"].items()} + scaling_fits = {k: ScalingFit(*v) for k, v in result_dict["scaling_fits"].items()} # Create plots fig_isoflop = create_isoflop_plot(isoflop_df, minima_records, fit_curves) @@ -1391,13 +1435,13 @@ def run_isoflop_analysis( logger.info(f"Transformed {len(isoflop_df)} runs for scaling ladder analysis") - minima_records, scaling_fits, fit_curves = fit_scaling_laws(isoflop_df) - configs = _minima_to_candidates(minima_records) + fit_result = fit_scaling_laws(isoflop_df) + configs = _minima_to_candidates(fit_result.minima_records) return IsoFlopAnalysisResult( configs=configs, - scaling_fits=scaling_fits, + scaling_fits=fit_result.scaling_fits, isoflop_df=isoflop_df, - minima_records=minima_records, - fit_curves=fit_curves, + minima_records=fit_result.minima_records, + fit_curves=fit_result.fit_curves, ) diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 81f0943bb2..3a466192b5 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -60,6 +60,7 @@ from marin.processing.tokenize.tokenize import TokenizeConfig from marin.scaling_laws.isoflop_analysis import ( IsoFlopSweepConfig, + ScalingFit, build_model_config, build_optimizer_config, isoflop_analysis_step, @@ -153,11 +154,11 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: with fs.open(result_path, "r") as f: analysis_result = json.load(f) - scaling_fits: dict[str, tuple[float, float]] = {} + scaling_fits: dict[str, ScalingFit] = {} for key, value in analysis_result["scaling_fits"].items(): if len(value) != 2: raise ValueError(f"Expected 2 scaling fit values for '{key}', got {len(value)}") - scaling_fits[key] = (float(value[0]), float(value[1])) + scaling_fits[key] = ScalingFit(float(value[0]), float(value[1])) vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) diff --git a/lib/marin/src/marin/scaling_laws/scaling_plots.py b/lib/marin/src/marin/scaling_laws/scaling_plots.py index ceba24ee65..106685ad76 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_plots.py +++ 
b/lib/marin/src/marin/scaling_laws/scaling_plots.py @@ -28,6 +28,7 @@ import plotly.graph_objects as go import plotly.io as pio +from marin.scaling_laws.isoflop_analysis import QuadraticFitCoeffs, ScalingFit from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT import wandb @@ -83,14 +84,14 @@ def create_isoflop_plot( df: pd.DataFrame, minima_records: list, - fit_curves: dict[tuple[str, float], tuple[float, float, float, float, float]], + fit_curves: dict[tuple[str, float], QuadraticFitCoeffs], ) -> go.Figure: """Create the IsoFLOP plot showing loss vs tokens for each compute budget. Args: df: DataFrame with columns: tokens, loss, flops, params, name, label - minima_records: List of dicts with optimal config info per (label, flops) - fit_curves: Dict of {(label, flops): (a, b, c, token_min, token_max)} quadratic fit coefficients + minima_records: List of MinimaRecord with optimal config info per (label, flops) + fit_curves: Dict of {(label, flops): QuadraticFitCoeffs} quadratic fit coefficients Returns: Plotly Figure with the isoflop visualization @@ -186,13 +187,13 @@ def create_isoflop_plot( def create_scaling_plot( minima_records: list, - scaling_fits: dict[str, tuple[float, float]], + scaling_fits: dict[str, ScalingFit], ) -> go.Figure: """Create the scaling law fit plot showing N* vs compute budget. Args: minima_records: List of MinimaRecord with optimal config info per (label, flops) - scaling_fits: Dict of {label: (alpha, A)} for N* ~ A * C^alpha + scaling_fits: Dict of {label: ScalingFit} for N* ~ A * C^alpha Returns: Plotly Figure with the scaling fit visualization @@ -316,7 +317,7 @@ def upload_plots_to_wandb( ) wandb.log( { - "isoFLOP_plot": wandb.Plotly(fig_isoflop), + "isoflop_plot": wandb.Plotly(fig_isoflop), "scaling_plot": wandb.Plotly(fig_scaling), } ) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index efa43046a6..279e66a488 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -442,16 +442,16 @@ def test_end_to_end_analysis_pipeline(): assert len(isoflop_df) == 6 # Fit scaling laws - minima_records, _scaling_fits, _ = fit_scaling_laws(isoflop_df) + fit_result = fit_scaling_laws(isoflop_df) # Should find two minima (one per budget: 1e18 and 1e19) - assert len(minima_records) == 2 - flops_budgets = {rec.flops for rec in minima_records} + assert len(fit_result.minima_records) == 2 + flops_budgets = {rec.flops for rec in fit_result.minima_records} assert flops_budgets == {1e18, 1e19} # Verify fitted minima are near expected optimal points # Curve fitting interpolates to find analytical minimum of fitted quadratic - minima_by_flops = {rec.flops: rec for rec in minima_records} + minima_by_flops = {rec.flops: rec for rec in fit_result.minima_records} # At 1e18: raw data optimal at 2.5B (loss=1.12), fitted minimum ~2.6B assert abs(minima_by_flops[1e18].optimal_tokens - 2.6e9) < 0.2e9 From 878d38b489934ad36170818e4bf11b0dea35e230 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 6 Jan 2026 15:28:01 -0800 Subject: [PATCH 40/79] Try to segment out the opinionated stuff --- .../exp2166_scaling_ladder_analysis.py | 64 ++++-- experiments/isoflop_sweep.py | 198 ++++++++---------- lib/marin/src/marin/scaling_laws/__init__.py | 10 + .../marin/scaling_laws/isoflop_analysis.py | 114 +++++----- .../src/marin/scaling_laws/scaling_ladder.py | 27 ++- 5 files changed, 219 insertions(+), 194 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 
cf26271a63..1c5f1c187e 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -24,35 +24,63 @@ The analysis steps depend on completed isoflop training runs from isoflop_sweep.py. Once complete, results are saved to the output path and uploaded to WandB. + +This experiment creates ExecutorSteps directly rather than using library factory +functions, following the pattern of isolating executor step creation to experiments. """ from experiments.defaults import default_validation_sets from experiments.isoflop_sweep import MARIN_SCALING_SUITES, nemotron_mix -from marin.execution.executor import executor_main -from marin.scaling_laws import scaling_ladder_suite +from marin.execution.executor import ExecutorStep, executor_main, output_path_of +from marin.scaling_laws import ( + IsoFlopAnalysisConfig, + ScalingLadderRungConfig, + run_isoflop_analysis_step, + run_scaling_ladder_rung, +) +from marin.scaling_laws.recipe import MARIN_2025_RECIPE -# Get training steps and datasets for each suite +# Get training steps from the isoflop sweep nemotron_training, _ = MARIN_SCALING_SUITES["nemotron"] -# --- Scaling Ladder Suites --- -# These analyze completed isoflop training runs and optionally train compute-optimal models - -# Target budgets for compute-optimal training runs (beyond the isoflop sweep) -# Set to empty list to only run analysis without training +# --- Configuration --- TARGET_BUDGETS: list[float] = [1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20] +EXPERIMENT_NAME = "exp2166-scaling-ladder-nemotron-validation" +LABEL = "nemo-wider-depth-adapt" - -nemotron_suite = scaling_ladder_suite( - name="exp2166-scaling-ladder-nemotron-validation", - training_runs=nemotron_training, - target_budgets=TARGET_BUDGETS, - label="nemo-wider-depth-adapt", - tokenized=nemotron_mix, - wandb_project="marin-analysis", - validation_sets=default_validation_sets(tokenizer="stanford-crfm/marin-tokenizer"), +# --- Step 1: IsoFLOP Analysis --- +# Creates scaling law fits from the training runs +analysis_step = ExecutorStep( + name=f"{EXPERIMENT_NAME}-analysis", + fn=run_isoflop_analysis_step, + config=IsoFlopAnalysisConfig( + training_runs=[output_path_of(r) for r in nemotron_training], + output_path=f"analysis/{EXPERIMENT_NAME}", + recipe=MARIN_2025_RECIPE, + ), ) -all_steps = [*nemotron_suite.all_steps] +# --- Step 2: Optimal Training Runs --- +# Train compute-optimal models at each target budget +optimal_runs: list[ExecutorStep] = [] +for budget in TARGET_BUDGETS: + step = ExecutorStep( + name=f"{EXPERIMENT_NAME}-optimal-{budget:.0e}", + fn=run_scaling_ladder_rung, + config=ScalingLadderRungConfig( + analysis_output_path=output_path_of(analysis_step), + target_budget=budget, + label=LABEL, + tokenized=nemotron_mix, + output_path=f"checkpoints/{EXPERIMENT_NAME}-optimal-{budget:.0e}", + recipe=MARIN_2025_RECIPE, + validation_sets=default_validation_sets(tokenizer="stanford-crfm/marin-tokenizer"), + ), + ) + optimal_runs.append(step) + +# All steps for this experiment +all_steps = [analysis_step, *optimal_runs] if __name__ == "__main__": executor_main(steps=all_steps) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index c198c8d758..67c32d56d4 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -15,16 +15,22 @@ """Generate ISOFlop sweep steps for varying model sizes on a target dataset. 
This script constructs `ExecutorStep` objects that train models of different -sizes while keeping the total training FLOPs roughly constant. It is intended +sizes while keeping the total training FLOPs roughly constant. It is intended as a lightweight scaffold for ISOFlop scaling law experiments. + +ExecutorSteps are created directly in this experiment file, following the pattern +of isolating executor step creation to experiments. The library provides: +- `generate_isoflop_train_args()`: Computes model/optimizer configs for each sweep point +- `IsoFlopSweepConfig`: Configuration for the sweep parameters +- `ScalingRecipe`: Named hyperparameter bundle + +This file uses those to create the actual ExecutorSteps. """ import dataclasses -from dataclasses import dataclass, replace +from dataclasses import replace from levanter.data.text import LMMixtureDatasetConfig -from levanter.optim.cautious import CautiousConfig -from levanter.optim.config import OptimizerConfig from experiments.evals.evals import default_eval from experiments.evals.task_configs import EvalTaskConfig @@ -37,94 +43,73 @@ from fray.cluster import ResourceConfig from marin.execution.executor import ExecutorStep, InputName, executor_main from marin.processing.tokenize import get_vocab_size_for_tokenizer, lm_mixture_data_config -from marin.scaling_laws.isoflop_analysis import ( +from marin.scaling_laws import ( CandidateConfig, IsoFlopSweepConfig, generate_isoflop_train_args, ) +from marin.scaling_laws.recipe import MARIN_2025_RECIPE, ScalingRecipe -@dataclass(frozen=True) -class IsoFlopExperimentConfig: - """Configuration for isoflop experiments with dataset and eval settings. - - Composes an IsoFlopSweepConfig for core sweep parameters and adds - experiment-specific settings like tokenized dataset and eval tasks. - """ - - tokenized_dataset: InputName | str - """Tokenized dataset to train on.""" - - sweep_config: IsoFlopSweepConfig = dataclasses.field(default_factory=IsoFlopSweepConfig) - """Core sweep parameters (budgets, seq_len, etc.).""" - - eval_tasks: tuple[EvalTaskConfig, ...] | None = None - """Evaluation tasks to run after training (disabled by default).""" - - base_optimizer_config: OptimizerConfig = dataclasses.field( - default_factory=lambda: CautiousConfig( - learning_rate=1.0, # Placeholder - weight_decay=0.1, - min_lr_ratio=0.0, - warmup=0.1, - beta1=0.95, - beta2=0.98, - epsilon=1e-15, - max_grad_norm=1, - adamc_weight_decay=True, - lr_schedule="linear", - decay=0.2, - ), - ) - - base_train_config: SimpleTrainConfig = dataclasses.field( - default_factory=lambda: SimpleTrainConfig( - resources=ResourceConfig.with_tpu("v5p-8"), - train_batch_size=1, - num_train_steps=50_000, - learning_rate=1.0, # Placeholder - weight_decay=0.1, - min_lr_ratio=0.0, - lr_schedule="linear", - decay=0.2, - ) - ) - - -def generate_isoflop_steps( - config: IsoFlopExperimentConfig, +def create_isoflop_sweep_steps( + tokenized: InputName | str | LMMixtureDatasetConfig, experiment_name: str, + recipe: ScalingRecipe, + sweep_config: IsoFlopSweepConfig | None = None, + eval_tasks: tuple[EvalTaskConfig, ...] | None = None, ) -> tuple[list[ExecutorStep], list[CandidateConfig]]: - """Generate executor steps for an ISOFlop sweep. + """Create ExecutorSteps for an ISOFlop sweep. + + This function creates ExecutorSteps directly in experiment code, using + `generate_isoflop_train_args()` from the library to compute configs. 
- Uses generate_isoflop_train_args() from the scaling_laws library to get - model configs, optimizer configs, and other arguments, then constructs - ExecutorSteps using default_train(). + Args: + tokenized: Tokenized dataset to train on. + experiment_name: Name suffix for the experiment (e.g., 'nemo', 'dclm'). + recipe: ScalingRecipe with hyperparameters - must be explicitly specified. + sweep_config: Optional custom sweep config. Uses defaults with the recipe if None. + eval_tasks: Optional evaluation tasks to run after training. Returns: A tuple of: - steps: Training and evaluation ExecutorSteps for the sweep. - - candidates: CandidateConfig for each training run (contains budget, hidden_size, - num_layers, batch_size, train_steps, learning_rate, etc.) + - candidates: CandidateConfig for each training run with full config details. """ - vocab_size = get_vocab_size_for_tokenizer(config.sweep_config.tokenizer) + # Build sweep config with the specified recipe + if sweep_config is None: + sweep_config = IsoFlopSweepConfig(recipe=recipe) + else: + sweep_config = dataclasses.replace(sweep_config, recipe=recipe) + + vocab_size = get_vocab_size_for_tokenizer(sweep_config.tokenizer) - # Get training arguments from the library + # Library provides the training arguments (model configs, optimizer configs, etc.) train_args_list = generate_isoflop_train_args( - sweep_config=config.sweep_config, + sweep_config=sweep_config, experiment_name=experiment_name, vocab_size=vocab_size, - base_optimizer_config=config.base_optimizer_config, ) - train_steps_list: list[ExecutorStep] = [] + # Base config for training runs + base_train_config = SimpleTrainConfig( + resources=ResourceConfig.with_tpu("v5p-8"), + train_batch_size=1, + num_train_steps=50_000, + learning_rate=1.0, # Placeholder, will be overridden + weight_decay=recipe.weight_decay, + min_lr_ratio=recipe.min_lr_ratio, + lr_schedule=recipe.lr_schedule, + decay=recipe.decay, + ) + + train_steps: list[ExecutorStep] = [] eval_steps: list[ExecutorStep] = [] candidates: list[CandidateConfig] = [] + # Create ExecutorSteps for each candidate configuration for args in train_args_list: - # Build SimpleTrainConfig from the library-provided arguments train_cfg = replace( - config.base_train_config, + base_train_config, train_batch_size=args.candidate.batch_size, learning_rate=args.candidate.learning_rate, num_train_steps=args.candidate.train_steps, @@ -132,10 +117,10 @@ def generate_isoflop_steps( optimizer_config=args.optimizer_config, ) - # Create training step using default_train + # Create training step train_step = default_train( name=args.run_name, - tokenized=config.tokenized_dataset, + tokenized=tokenized, model_config=args.model_config, train_config=train_cfg, eval_harness_tasks=[], @@ -144,50 +129,23 @@ def generate_isoflop_steps( # Pin to static output path for checkpoint reuse train_step = train_step.with_output_path(args.output_path) - train_steps_list.append(train_step) + train_steps.append(train_step) candidates.append(args.candidate) - # Evaluation on the latest checkpoint for each ISOFlop run - if config.eval_tasks: + # Create evaluation step if eval tasks specified + if eval_tasks: eval_step = default_eval( train_step, resource_config=train_cfg.resources, - evals=config.eval_tasks, + evals=eval_tasks, ) eval_steps.append(eval_step) - all_steps: list[ExecutorStep] = [*train_steps_list, *eval_steps] + all_steps: list[ExecutorStep] = [*train_steps, *eval_steps] return all_steps, candidates -def generate_isoflop_sweep( - tokenized: InputName | 
ExecutorStep | LMMixtureDatasetConfig, - experiment_name: str, - sweep_config: IsoFlopSweepConfig | None = None, - eval_tasks: tuple[EvalTaskConfig, ...] | None = None, -) -> tuple[list[ExecutorStep], list[CandidateConfig]]: - """Generate an ISOFlop sweep for a tokenized dataset. - - Args: - tokenized: Tokenized dataset to train on. - experiment_name: Name suffix for the experiment (e.g., 'nemo', 'dclm'). - sweep_config: Optional custom sweep config. Uses defaults if None. - eval_tasks: Optional evaluation tasks to run after training. - - Returns: - A tuple of: - - steps: Training and evaluation ExecutorSteps for the sweep. - - candidates: CandidateConfig for each training run with full config details. - """ - config = IsoFlopExperimentConfig( - tokenized_dataset=tokenized, - sweep_config=sweep_config or IsoFlopSweepConfig(), - eval_tasks=eval_tasks, - ) - steps, candidates = generate_isoflop_steps(config, experiment_name) - - return steps, candidates - +# --- Tokenized Datasets --- dclm_tokenized = dataclasses.replace( default_tokenize( @@ -197,7 +155,6 @@ def generate_isoflop_sweep( ).with_output_path("tokenized/dclm_baseline-0206f1/"), ) - dclm_mix = lm_mixture_data_config( components={"dclm": dclm_tokenized}, weights={"dclm": 1.0}, @@ -218,14 +175,37 @@ def generate_isoflop_sweep( num_validation_sequences={"dolma3_mix-150B-1025": 1024}, ) + +# --- Scaling Suites --- +# Each suite explicitly specifies the recipe for visibility. +# ExecutorSteps are created by create_isoflop_sweep_steps() in this file. + MARIN_SCALING_SUITES = { - "nemotron": generate_isoflop_sweep(nemotron_mix, experiment_name="nemo-wider-depth-adapt"), - "common_pile": generate_isoflop_sweep(comma_main_mixture(permutation_type="linear"), experiment_name="comma-mix"), - "common_pile_feistel": generate_isoflop_sweep( - comma_main_mixture(permutation_type="feistel"), experiment_name="comma-mix-feistel" + "nemotron": create_isoflop_sweep_steps( + tokenized=nemotron_mix, + experiment_name="nemo-wider-depth-adapt", + recipe=MARIN_2025_RECIPE, + ), + "common_pile": create_isoflop_sweep_steps( + tokenized=comma_main_mixture(permutation_type="linear"), + experiment_name="comma-mix", + recipe=MARIN_2025_RECIPE, + ), + "common_pile_feistel": create_isoflop_sweep_steps( + tokenized=comma_main_mixture(permutation_type="feistel"), + experiment_name="comma-mix-feistel", + recipe=MARIN_2025_RECIPE, + ), + "dclm-default": create_isoflop_sweep_steps( + tokenized=dclm_mix, + experiment_name="dclm-default", + recipe=MARIN_2025_RECIPE, + ), + "dolma3_mix_150b": create_isoflop_sweep_steps( + tokenized=dolma3_mix, + experiment_name="dolma3-mix-150b-1025", + recipe=MARIN_2025_RECIPE, ), - "dclm-default": generate_isoflop_sweep(dclm_mix, experiment_name="dclm-default"), - "dolma3_mix_150b": generate_isoflop_sweep(dolma3_mix, experiment_name="dolma3-mix-150b-1025"), } if __name__ == "__main__": diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index caf51f8383..6e639903b5 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -37,11 +37,17 @@ predict_optimal_config, predict_optimal_configs_for_budgets, run_isoflop_analysis, + run_isoflop_analysis_step, upload_isoflop_plots_to_wandb_step, ) +from marin.scaling_laws.recipe import ( + MARIN_2025_RECIPE, + ScalingRecipe, +) from marin.scaling_laws.scaling_ladder import ( ScalingLadderRungConfig, ScalingLadderSuite, + run_scaling_ladder_rung, scaling_ladder_rung_step, scaling_ladder_suite, 
) @@ -54,6 +60,7 @@ __all__ = [ "DEFAULT_BUDGETS", + "MARIN_2025_RECIPE", "CandidateConfig", "FitScalingLawsResult", "IsoFlopAnalysisConfig", @@ -66,6 +73,7 @@ "ScalingFit", "ScalingLadderRungConfig", "ScalingLadderSuite", + "ScalingRecipe", "UploadPlotsToWandbConfig", "build_model_config", "build_optimizer_config", @@ -81,6 +89,8 @@ "predict_optimal_config", "predict_optimal_configs_for_budgets", "run_isoflop_analysis", + "run_isoflop_analysis_step", + "run_scaling_ladder_rung", "save_plots", "scaling_ladder_rung_step", "scaling_ladder_suite", diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 8c7b465565..6516ed8c34 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -74,6 +74,7 @@ extract_run_name_from_path, read_metrics_dataframe, ) +from marin.scaling_laws.recipe import MARIN_2025_RECIPE, ScalingRecipe from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT logger = logging.getLogger(__name__) @@ -87,18 +88,6 @@ # ---------------- IsoFLOP Sweep Constants ---------------- DEFAULT_BUDGETS: tuple[float, ...] = (1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20) -MLP_RATIO = 4 - -# Learning rate scaling: lr = LR_CONSTANT * sqrt(batch_size) / hidden_dim -LR_CONSTANT = 0.33 - -# Head size for attention: num_heads = hidden_dim / HIDDEN_HEAD_RATIO -HIDDEN_HEAD_RATIO = 128 - -# Beta2 scaling for Adam: beta2 = BETA2_BASE ** (batch_size / BETA2_BATCH_DIVISOR) -# Reference: https://arxiv.org/pdf/2507.07101 -BETA2_BASE = 0.98 -BETA2_BATCH_DIVISOR = 128 # TPU v5p hardware constants for memory estimation HBM_PER_CHIP_GIB = 95 @@ -147,6 +136,9 @@ class IsoFlopSweepConfig: hyperparameters for isoflop experiments. 
""" + recipe: ScalingRecipe = MARIN_2025_RECIPE + """Scaling recipe with hyperparameters (learning rate, beta2, optimizer settings).""" + tokenizer: str = "stanford-crfm/marin-tokenizer" """Tokenizer to use (needed for vocab size).""" @@ -165,12 +157,6 @@ class IsoFlopSweepConfig: base_hidden_layer_ratio: int = 64 """Base ratio for hidden_dim to num_layers calculation.""" - hidden_head_ratio: int = HIDDEN_HEAD_RATIO - """Ratio for hidden_dim to num_heads calculation.""" - - lr_constant: float = LR_CONSTANT - """Constant for learning rate calculation: lr = (lr_constant * sqrt(batch)) / hidden_dim.""" - min_hidden_pow: int = 9 """Minimum hidden dimension as power of 2 (2^9 = 512).""" @@ -405,11 +391,12 @@ def candidate_configs( else: step_size = 128 + recipe = cfg.recipe for hidden_size in range(2**cfg.min_hidden_pow, (2**cfg.max_hidden_pow) + 1, step_size): hs_pow = math.log2(hidden_size) - intermediate_dim = hidden_size * MLP_RATIO + intermediate_dim = hidden_size * recipe.mlp_ratio num_layers = round(hidden_size / (cfg.base_hidden_layer_ratio + (hs_pow * 4) - cfg.min_hidden_pow)) - n_heads = max(1, hidden_size // cfg.hidden_head_ratio) + n_heads = max(1, hidden_size // recipe.hidden_head_ratio) n_kv_heads = n_heads batch_exact = budget / compute_total_flops( @@ -425,11 +412,11 @@ def candidate_configs( ) batch_size = round_to_power_of_two(batch_exact) - lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size + lr = recipe.compute_learning_rate(batch_size, hidden_size) while lr > 0.01: batch_size //= 2 - lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size - b2 = BETA2_BASE ** (batch_size / BETA2_BATCH_DIVISOR) + lr = recipe.compute_learning_rate(batch_size, hidden_size) + b2 = recipe.compute_beta2(batch_size) if batch_size < 8: continue @@ -545,32 +532,46 @@ def build_model_config(candidate: CandidateConfig, seq_len: int = SEQ_LEN) -> Qw ) -def build_optimizer_config(candidate: CandidateConfig) -> CautiousConfig: - """Build optimizer config from a CandidateConfig. +def build_optimizer_config( + candidate: CandidateConfig, + recipe: ScalingRecipe = MARIN_2025_RECIPE, +) -> CautiousConfig: + """Build optimizer config from a CandidateConfig and ScalingRecipe. This is the shared builder used by both generate_isoflop_train_args() and scaling_ladder's run_scaling_ladder_rung() to ensure consistent optimizer configs. + + Args: + candidate: CandidateConfig with learning_rate and beta2. + recipe: ScalingRecipe with optimizer hyperparameters. """ return CautiousConfig( learning_rate=candidate.learning_rate, - weight_decay=0.1, - min_lr_ratio=0.0, - warmup=0.1, - beta1=0.95, + weight_decay=recipe.weight_decay, + min_lr_ratio=recipe.min_lr_ratio, + warmup=recipe.warmup, + beta1=recipe.beta1, beta2=candidate.beta2, - epsilon=1e-15, - max_grad_norm=1, + epsilon=recipe.epsilon, + max_grad_norm=recipe.max_grad_norm, adamc_weight_decay=True, - lr_schedule="linear", - decay=0.2, + lr_schedule=recipe.lr_schedule, + decay=recipe.decay, ) -def _minima_to_candidates(minima_records: list[MinimaRecord]) -> list[CandidateConfig]: +def _minima_to_candidates( + minima_records: list[MinimaRecord], + recipe: ScalingRecipe = MARIN_2025_RECIPE, +) -> list[CandidateConfig]: """Convert minima records to CandidateConfig objects. - This is used by both _run_isoflop_analysis_step() and run_isoflop_analysis() + This is used by both run_isoflop_analysis_step() and run_isoflop_analysis() to convert the fitted minima into usable candidate configs. 
+ + Args: + minima_records: List of optimal configurations from scaling law fits. + recipe: ScalingRecipe with architecture and hyperparameter settings. """ configs = [] for rec in minima_records: @@ -579,14 +580,14 @@ def _minima_to_candidates(minima_records: list[MinimaRecord]) -> list[CandidateC configs.append( CandidateConfig( hidden_size=rec.hidden_dim, - intermediate_dim=rec.hidden_dim * MLP_RATIO, + intermediate_dim=rec.hidden_dim * recipe.mlp_ratio, num_layers=rec.num_layers, - num_heads=max(1, rec.hidden_dim // HIDDEN_HEAD_RATIO), - num_kv_heads=max(1, rec.hidden_dim // HIDDEN_HEAD_RATIO), + num_heads=max(1, rec.hidden_dim // recipe.hidden_head_ratio), + num_kv_heads=max(1, rec.hidden_dim // recipe.hidden_head_ratio), batch_size=rec.batch_size, train_steps=int(rec.optimal_tokens / (rec.batch_size * SEQ_LEN)), - learning_rate=(LR_CONSTANT * math.sqrt(rec.batch_size)) / rec.hidden_dim, - beta2=BETA2_BASE ** (rec.batch_size / BETA2_BATCH_DIVISOR), + learning_rate=recipe.compute_learning_rate(rec.batch_size, rec.hidden_dim), + beta2=recipe.compute_beta2(rec.batch_size), tokens=rec.optimal_tokens, flops_budget=rec.flops, ) @@ -626,19 +627,20 @@ def generate_isoflop_train_args( ... # Use args.model_config, args.optimizer_config, etc. with default_train() ... pass """ + recipe = sweep_config.recipe if base_optimizer_config is None: base_optimizer_config = CautiousConfig( learning_rate=1.0, # Placeholder, will be overridden - weight_decay=0.1, - min_lr_ratio=0.0, - warmup=0.1, - beta1=0.95, + weight_decay=recipe.weight_decay, + min_lr_ratio=recipe.min_lr_ratio, + warmup=recipe.warmup, + beta1=recipe.beta1, beta2=0.98, # Placeholder, will be overridden - epsilon=1e-15, - max_grad_norm=1, + epsilon=recipe.epsilon, + max_grad_norm=recipe.max_grad_norm, adamc_weight_decay=True, - lr_schedule="linear", - decay=0.2, + lr_schedule=recipe.lr_schedule, + decay=recipe.decay, ) results: list[IsoFlopTrainArgs] = [] @@ -1074,6 +1076,9 @@ def _parse_fit_curve_coeffs(coeffs: Sequence[float]) -> QuadraticFitCoeffs: class IsoFlopAnalysisConfig(EvalMetricsAnalysisConfig): """Configuration for scaling ladder analysis ExecutorStep.""" + recipe: ScalingRecipe = MARIN_2025_RECIPE + """Scaling recipe for computing optimal hyperparameters.""" + metric_key: str = DEFAULT_METRIC_KEY """Metric to use for loss (default: eval/paloma/c4_en/bpb).""" @@ -1109,7 +1114,7 @@ class UploadPlotsToWandbConfig: """Name for the WandB run.""" -def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: +def run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: """Execute scaling ladder analysis (called by ExecutorStep).""" raw_df = read_metrics_dataframe(config) @@ -1134,7 +1139,7 @@ def _run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: for label, (alpha, A) in fit_result.scaling_fits.items(): logger.info(f" {label}: N* = {A:.2e} * C^{alpha:.3f}") - configs = _minima_to_candidates(fit_result.minima_records) + configs = _minima_to_candidates(fit_result.minima_records, config.recipe) result = IsoFlopAnalysisResult( configs=configs, @@ -1258,6 +1263,7 @@ def isoflop_analysis_step( training_runs: Sequence[ExecutorStep | InputName], metric_key: str = DEFAULT_METRIC_KEY, label_map: dict[str, str] | None = None, + recipe: ScalingRecipe = MARIN_2025_RECIPE, ) -> ExecutorStep: """Create an ExecutorStep for scaling ladder analysis. 
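Illustrative sketch (not part of the patch): with the new `recipe` parameter, a non-default recipe can be passed explicitly when building the analysis step. The `ScalingRecipe` fields shown follow the example in the recipe module's docstring; the values and `my_training_steps` are hypothetical placeholders.

    from marin.scaling_laws import ScalingRecipe, isoflop_analysis_step

    my_recipe = ScalingRecipe(name="my-sweep", lr_constant=0.25, weight_decay=0.05)
    analysis = isoflop_analysis_step(
        name="my-scaling-analysis",
        training_runs=my_training_steps,  # hypothetical list of training ExecutorSteps
        recipe=my_recipe,
    )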
@@ -1270,6 +1276,7 @@ def isoflop_analysis_step( training_runs: Training run ExecutorSteps or InputNames to analyze metric_key: Which metric to use for loss (default: eval/paloma/c4_en/bpb) label_map: Optional mapping from experiment_name -> display label + recipe: ScalingRecipe with hyperparameters Returns: ExecutorStep configured to run the analysis @@ -1290,13 +1297,14 @@ def isoflop_analysis_step( config = IsoFlopAnalysisConfig( training_runs=run_paths, output_path=this_output_path(), + recipe=recipe, metric_key=metric_key, label_map=tuple(label_map.items()) if label_map else None, ) return ExecutorStep( name=name, - fn=_run_isoflop_analysis_step, + fn=run_isoflop_analysis_step, config=config, description=f"Scaling ladder analysis for {len(training_runs)} training runs", ) @@ -1389,6 +1397,7 @@ def run_isoflop_analysis( training_runs: Sequence[ExecutorStep] | Sequence[str], metric_key: str = DEFAULT_METRIC_KEY, label_map: dict[str, str] | None = None, + recipe: ScalingRecipe = MARIN_2025_RECIPE, ) -> IsoFlopAnalysisResult: """Analyze isoflop training runs and return optimal training configurations. @@ -1399,6 +1408,7 @@ def run_isoflop_analysis( training_runs: List of ExecutorSteps or path strings to training runs metric_key: Which metric to use for loss (default: eval/paloma/c4_en/bpb) label_map: Optional mapping from experiment_name -> display label + recipe: ScalingRecipe with hyperparameter settings Returns: IsoFlopAnalysisResult with configs, scaling_fits, and analysis data @@ -1436,7 +1446,7 @@ def run_isoflop_analysis( logger.info(f"Transformed {len(isoflop_df)} runs for scaling ladder analysis") fit_result = fit_scaling_laws(isoflop_df) - configs = _minima_to_candidates(fit_result.minima_records) + configs = _minima_to_candidates(fit_result.minima_records, recipe) return IsoFlopAnalysisResult( configs=configs, diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 3a466192b5..f1ea8fd865 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -67,8 +67,8 @@ pick_v5p_type, predict_optimal_config, ) +from marin.scaling_laws.recipe import MARIN_2025_RECIPE, ScalingRecipe from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm -from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT logger = logging.getLogger(__name__) @@ -133,6 +133,9 @@ class ScalingLadderRungConfig: output_path: str """Where to write training outputs.""" + recipe: ScalingRecipe = MARIN_2025_RECIPE + """Scaling recipe with hyperparameters.""" + tokenizer: str = "stanford-crfm/marin-tokenizer" """Tokenizer to use.""" @@ -194,7 +197,7 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: vocab_size, ) - optimizer_cfg = build_optimizer_config(candidate) + optimizer_cfg = build_optimizer_config(candidate, config.recipe) pretraining_data = _prepare_data_config(config.tokenized, config.validation_sets) @@ -247,6 +250,7 @@ def scaling_ladder_rung_step( target_budget: float, label: str, tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, + recipe: ScalingRecipe = MARIN_2025_RECIPE, tokenizer: str = "stanford-crfm/marin-tokenizer", seq_len: int = 4096, override_output_path: str | None = None, @@ -264,6 +268,7 @@ def scaling_ladder_rung_step( label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') tokenized: Tokenized dataset to train on. Can be an ExecutorStep, InputName, or LMMixtureDatasetConfig. 
+ recipe: ScalingRecipe with hyperparameters tokenizer: Tokenizer to use seq_len: Sequence length for training override_output_path: Optional override for the output path @@ -287,6 +292,7 @@ def scaling_ladder_rung_step( label=label, tokenized=resolved_tokenized, output_path=output_path, + recipe=recipe, tokenizer=tokenizer, seq_len=seq_len, validation_sets=validation_sets, @@ -334,14 +340,11 @@ def scaling_ladder_suite( target_budgets: Sequence[float], label: str, tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, + recipe: ScalingRecipe = MARIN_2025_RECIPE, tokenizer: str = "stanford-crfm/marin-tokenizer", seq_len: int = 4096, metric_key: str = "eval/paloma/c4_en/bpb", label_map: dict[str, str] | None = None, - save_plots: bool = True, - upload_to_wandb: bool = True, - wandb_entity: str = WANDB_ENTITY, - wandb_project: str = f"{WANDB_PROJECT}-analysis", validation_sets: dict[str, TokenizerStep] | None = None, ) -> ScalingLadderSuite: """Create a complete scaling ladder: IsoFLOP analysis + optimal training runs. @@ -360,14 +363,11 @@ def scaling_ladder_suite( label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') tokenized: Tokenized dataset for optimal training runs. Can be an ExecutorStep, InputName, or LMMixtureDatasetConfig. + recipe: ScalingRecipe with hyperparameters tokenizer: Tokenizer to use seq_len: Sequence length for training metric_key: Which metric to use for loss label_map: Optional mapping from experiment_name -> display label - save_plots: Whether to save HTML plots - upload_to_wandb: Whether to upload plots to WandB - wandb_entity: WandB entity for uploads - wandb_project: WandB project for uploads validation_sets: Optional validation sets for eval loss tracking Returns: @@ -388,11 +388,7 @@ def scaling_ladder_suite( training_runs=training_runs, metric_key=metric_key, label_map=label_map, - save_plots=save_plots, - upload_to_wandb=upload_to_wandb, - wandb_entity=wandb_entity, - wandb_project=wandb_project, - wandb_run_name=f"{name}-analysis", + recipe=recipe, ) optimal_runs = [] @@ -403,6 +399,7 @@ def scaling_ladder_suite( target_budget=budget, label=label, tokenized=tokenized, + recipe=recipe, tokenizer=tokenizer, seq_len=seq_len, validation_sets=validation_sets, From 20a584c7c9d91e4e484e3112a0911b2c6c3f89d4 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 6 Jan 2026 19:02:02 -0800 Subject: [PATCH 41/79] Fix FLOP counting bug --- .../marin/scaling_laws/isoflop_analysis.py | 12 ++- lib/marin/src/marin/scaling_laws/recipe.py | 94 +++++++++++++++++++ tests/test_scaling_laws.py | 28 +++--- 3 files changed, 119 insertions(+), 15 deletions(-) create mode 100644 lib/marin/src/marin/scaling_laws/recipe.py diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 6516ed8c34..a335b61cac 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -87,6 +87,8 @@ MARIN_TOKENIZER_VOCAB_SIZE = 128256 # ---------------- IsoFLOP Sweep Constants ---------------- +# Budgets in training FLOPs (includes 3x multiplier for forward + backward pass). +# This matches how FLOPs are tracked in WandB via Levanter's log_performance_stats. DEFAULT_BUDGETS: tuple[float, ...] 
= (1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20) # TPU v5p hardware constants for memory estimation @@ -298,7 +300,12 @@ def compute_total_flops( seq_len: int, vocab_size: int, ) -> float: - """Compute total training FLOPs using Levanter utilities.""" + """Compute total training FLOPs using Levanter utilities. + + This returns training FLOPs which includes forward pass (1x) + backward pass (2x) = 3x. + This matches the FLOP accounting in Levanter's log_performance_stats callback + (see train_lm.py) and standard ML conventions (e.g., Chinchilla paper). + """ flops_per_token = lm_flops_per_token( hidden, intermediate, @@ -309,7 +316,8 @@ def compute_total_flops( vocab_size, glu=True, ) - return flops_per_token * batch * steps * seq_len + # Multiply by 3 for training: forward (1x) + backward (2x) + return 3 * flops_per_token * batch * steps * seq_len def estimate_memory_bytes( diff --git a/lib/marin/src/marin/scaling_laws/recipe.py b/lib/marin/src/marin/scaling_laws/recipe.py new file mode 100644 index 0000000000..1f0dfe8fd1 --- /dev/null +++ b/lib/marin/src/marin/scaling_laws/recipe.py @@ -0,0 +1,94 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Scaling recipes: named hyperparameter bundles for scaling law experiments. + +A recipe makes "opinionated defaults" explicit and named, so users consciously +choose which set of hyperparameters to use rather than getting hidden defaults. + +Usage: + from marin.scaling_laws.recipe import MARIN_2025_RECIPE, ScalingRecipe + + # Use the default recipe + recipe = MARIN_2025_RECIPE + lr = recipe.compute_learning_rate(batch_size=256, hidden_dim=1024) + beta2 = recipe.compute_beta2(batch_size=256) + + # Or create a custom recipe + my_recipe = ScalingRecipe( + name="my-experiment", + lr_constant=0.25, + weight_decay=0.05, + ) +""" + +import math +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ScalingRecipe: + """A named set of hyperparameters for scaling law experiments. + + The recipe controls: + - Learning rate scaling formula + - Beta2 scaling formula (for Adam) + - Optimizer hyperparameters (weight decay, warmup, etc.) 
+ - Architecture ratios (MLP width, head size) + """ + + name: str + """Name identifying this recipe (e.g., 'marin-2025').""" + + # Learning rate scaling: lr = lr_constant * sqrt(batch_size) / hidden_dim + lr_constant: float = 0.33 + """Constant for learning rate calculation.""" + + # Beta2 scaling for Adam: beta2 = beta2_base ** (batch_size / beta2_batch_divisor) + # Reference: https://arxiv.org/pdf/2507.07101 + beta2_base: float = 0.98 + """Base for beta2 exponential scaling.""" + + beta2_batch_divisor: float = 128 + """Divisor for beta2 batch size scaling.""" + + # Optimizer hyperparameters + weight_decay: float = 0.1 + min_lr_ratio: float = 0.0 + warmup: float = 0.1 + beta1: float = 0.95 + epsilon: float = 1e-15 + max_grad_norm: float = 1.0 + lr_schedule: str = "linear" + decay: float = 0.2 + + # Architecture ratios + mlp_ratio: int = 4 + """MLP intermediate_dim = hidden_dim * mlp_ratio.""" + + hidden_head_ratio: int = 128 + """num_heads = hidden_dim / hidden_head_ratio.""" + + def compute_learning_rate(self, batch_size: int, hidden_dim: int) -> float: + """Compute learning rate from batch size and hidden dim.""" + return (self.lr_constant * math.sqrt(batch_size)) / hidden_dim + + def compute_beta2(self, batch_size: int) -> float: + """Compute beta2 from batch size.""" + return self.beta2_base ** (batch_size / self.beta2_batch_divisor) + + +# Named recipes +MARIN_2025_RECIPE = ScalingRecipe(name="marin-2025") +"""Default Marin scaling recipe based on 2025 best practices.""" diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 279e66a488..87dbb6eae0 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -184,9 +184,10 @@ def test_robust_quad_logx_fits_quadratic(): # --- Snapshot test for config generation --- -# Snapshot of expected output for generate_isoflop_train_args with budget=1e18. -# This captures the configuration generation logic to ensure reproducibility. -EXPECTED_ISOFLOP_CONFIGS_1E18 = [ +# Snapshot of expected output for generate_isoflop_train_args with budget=3e18 training FLOPs. +# Note: compute_total_flops includes the 3x multiplier for training (forward + backward pass), +# matching how FLOPs are tracked in WandB via Levanter's log_performance_stats. 
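As a quick sanity check, the renamed snapshot below stays consistent with the recipe formulas defined above. The check itself is illustrative and not part of the test file; the expected numbers are read from the first d512/B32 snapshot entry.

```python
# Sketch only: verify the d512/B32 snapshot entry against MARIN_2025_RECIPE's formulas.
import math

lr = (0.33 * math.sqrt(32)) / 512   # recipe.compute_learning_rate(batch_size=32, hidden_dim=512)
beta2 = 0.98 ** (32 / 128)          # recipe.compute_beta2(batch_size=32)

assert round(lr, 6) == 0.003646     # matches "learning_rate" in the first snapshot entry
assert round(beta2, 6) == 0.994962  # matches "beta2" in the first snapshot entry
```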
+EXPECTED_ISOFLOP_CONFIGS_3E18 = [ { "hidden_size": 512, "intermediate_dim": 2048, @@ -198,7 +199,7 @@ def test_robust_quad_logx_fits_quadratic(): "learning_rate": 0.003646, "beta2": 0.994962, "tpu_type": "v5p-8", - "run_name": "isoflop-1e+18-d512-L6-B32-test-snapshot", + "run_name": "isoflop-3e+18-d512-L6-B32-test-snapshot", }, { "hidden_size": 640, @@ -211,7 +212,7 @@ def test_robust_quad_logx_fits_quadratic(): "learning_rate": 0.002063, "beta2": 0.997478, "tpu_type": "v5p-8", - "run_name": "isoflop-1e+18-d640-L7-B16-test-snapshot", + "run_name": "isoflop-3e+18-d640-L7-B16-test-snapshot", }, { "hidden_size": 768, @@ -224,7 +225,7 @@ def test_robust_quad_logx_fits_quadratic(): "learning_rate": 0.001719, "beta2": 0.997478, "tpu_type": "v5p-8", - "run_name": "isoflop-1e+18-d768-L8-B16-test-snapshot", + "run_name": "isoflop-3e+18-d768-L8-B16-test-snapshot", }, { "hidden_size": 896, @@ -237,7 +238,7 @@ def test_robust_quad_logx_fits_quadratic(): "learning_rate": 0.001042, "beta2": 0.998738, "tpu_type": "v5p-8", - "run_name": "isoflop-1e+18-d896-L10-B8-test-snapshot", + "run_name": "isoflop-3e+18-d896-L10-B8-test-snapshot", }, { "hidden_size": 1024, @@ -250,7 +251,7 @@ def test_robust_quad_logx_fits_quadratic(): "learning_rate": 0.000912, "beta2": 0.998738, "tpu_type": "v5p-8", - "run_name": "isoflop-1e+18-d1024-L11-B8-test-snapshot", + "run_name": "isoflop-3e+18-d1024-L11-B8-test-snapshot", }, ] @@ -259,9 +260,10 @@ def test_generate_isoflop_train_args_snapshot(): """Snapshot test: verify generate_isoflop_train_args produces expected configs. This test ensures the scaling_laws module produces identical configurations - for reproducible isoflop sweeps. + for reproducible isoflop sweeps. Uses 3e18 training FLOPs budget (which accounts + for the 3x multiplier for forward + backward pass). 
""" - config = IsoFlopSweepConfig(budgets=(1e18,)) + config = IsoFlopSweepConfig(budgets=(3e18,)) result = generate_isoflop_train_args( sweep_config=config, experiment_name="test-snapshot", @@ -269,10 +271,10 @@ def test_generate_isoflop_train_args_snapshot(): ) assert len(result) == len( - EXPECTED_ISOFLOP_CONFIGS_1E18 - ), f"Expected {len(EXPECTED_ISOFLOP_CONFIGS_1E18)} configs, got {len(result)}" + EXPECTED_ISOFLOP_CONFIGS_3E18 + ), f"Expected {len(EXPECTED_ISOFLOP_CONFIGS_3E18)} configs, got {len(result)}" - for i, (args, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_1E18, strict=True)): + for i, (args, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_3E18, strict=True)): assert isinstance(args, IsoFlopTrainArgs) c = args.candidate actual = { From 3cc92141bfcaee46d8f32070973d064de6ddb008 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 6 Jan 2026 20:10:38 -0800 Subject: [PATCH 42/79] Try to separate opinions from main code --- lib/marin/src/marin/scaling_laws/__init__.py | 12 +- .../marin/scaling_laws/isoflop_analysis.py | 312 +++++++++--------- lib/marin/src/marin/scaling_laws/recipe.py | 113 ++++++- .../src/marin/scaling_laws/scaling_ladder.py | 14 +- tests/test_scaling_laws.py | 143 +++++--- 5 files changed, 373 insertions(+), 221 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 6e639903b5..69643ca65b 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -25,10 +25,10 @@ QuadraticFitCoeffs, ScalingFit, UploadPlotsToWandbConfig, - build_model_config, build_optimizer_config, candidate_configs, - compute_transformer_params, + candidate_to_model_config, + compute_training_flops, fit_scaling_laws, generate_isoflop_train_args, isoflop_analysis_step, @@ -38,6 +38,8 @@ predict_optimal_configs_for_budgets, run_isoflop_analysis, run_isoflop_analysis_step, + solve_for_batch_size, + solve_for_train_steps, upload_isoflop_plots_to_wandb_step, ) from marin.scaling_laws.recipe import ( @@ -75,10 +77,10 @@ "ScalingLadderSuite", "ScalingRecipe", "UploadPlotsToWandbConfig", - "build_model_config", "build_optimizer_config", "candidate_configs", - "compute_transformer_params", + "candidate_to_model_config", + "compute_training_flops", "create_isoflop_plot", "create_scaling_plot", "fit_scaling_laws", @@ -94,6 +96,8 @@ "save_plots", "scaling_ladder_rung_step", "scaling_ladder_suite", + "solve_for_batch_size", + "solve_for_train_steps", "upload_isoflop_plots_to_wandb_step", "upload_plots_to_wandb", ] diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index a335b61cac..9f52983494 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -61,9 +61,9 @@ import jax.numpy as jnp import pandas as pd from jaxopt import ScipyMinimize -from levanter.utils.flop_utils import lm_flops_per_token from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig +from levanter.models.llama import LlamaConfig from levanter.models.qwen import Qwen3Config from levanter.optim.cautious import CautiousConfig from levanter.optim.config import OptimizerConfig @@ -134,12 +134,19 @@ class QuadraticFitCoeffs(NamedTuple): class IsoFlopSweepConfig: """Configuration for generating ISOFlop sweep candidate configs. - This config controls the model architecture search space and training - hyperparameters for isoflop experiments. 
+ This config controls the FLOP budgets and training parameters. + Architecture decisions (num_layers formula, hidden_pow bounds, etc.) + are controlled by the ScalingRecipe. """ recipe: ScalingRecipe = MARIN_2025_RECIPE - """Scaling recipe with hyperparameters (learning rate, beta2, optimizer settings).""" + """Scaling recipe with all opinionated hyperparameters: + - Architecture formula (num_layers from hidden_size) + - Architecture ratios (mlp_ratio, hidden_head_ratio) + - Search bounds (min/max hidden_pow, step_size) + - Constraints (max_learning_rate, min_batch_size) + - Optimizer settings (weight_decay, warmup, etc.) + """ tokenizer: str = "stanford-crfm/marin-tokenizer" """Tokenizer to use (needed for vocab size).""" @@ -156,15 +163,6 @@ class IsoFlopSweepConfig: flop_tolerance: float = 0.01 """Tolerance for matching FLOP budget (relative error).""" - base_hidden_layer_ratio: int = 64 - """Base ratio for hidden_dim to num_layers calculation.""" - - min_hidden_pow: int = 9 - """Minimum hidden dimension as power of 2 (2^9 = 512).""" - - max_hidden_pow: int = 12 - """Maximum hidden dimension as power of 2 (2^12 = 4096).""" - # ---------------- Candidate Config ---------------- @@ -289,35 +287,84 @@ def round_flops_to_bucket(flops: float) -> float: return float(rounded_mantissa) * (10**exponent) -def compute_total_flops( - batch: int, - num_layers: int, - hidden: int, - intermediate: int, - num_kv_heads: int, - num_heads: int, - steps: int, - seq_len: int, +def compute_training_flops( + model_config: "LlamaConfig", vocab_size: int, + batch_size: int, + train_steps: int, + seq_len: int, ) -> float: - """Compute total training FLOPs using Levanter utilities. + """Compute total training FLOPs using the model config's own method. This returns training FLOPs which includes forward pass (1x) + backward pass (2x) = 3x. This matches the FLOP accounting in Levanter's log_performance_stats callback (see train_lm.py) and standard ML conventions (e.g., Chinchilla paper). + + Args: + model_config: Levanter model config with flops_per_token method (LlamaConfig or subclass). + vocab_size: Vocabulary size. + batch_size: Training batch size. + train_steps: Number of training steps. + seq_len: Sequence length. + + Returns: + Total training FLOPs (including 3x multiplier for forward + backward pass). """ - flops_per_token = lm_flops_per_token( - hidden, - intermediate, - num_layers, - num_kv_heads, - num_heads, - seq_len, - vocab_size, - glu=True, - ) + flops_per_token = model_config.flops_per_token(vocab_size, seq_len) # Multiply by 3 for training: forward (1x) + backward (2x) - return 3 * flops_per_token * batch * steps * seq_len + return 3 * flops_per_token * batch_size * train_steps * seq_len + + +def solve_for_batch_size( + model_config: "LlamaConfig", + vocab_size: int, + target_flops: float, + train_steps: int, + seq_len: int, +) -> float: + """Solve for batch size needed to hit a target FLOP budget. + + Given: total_flops = 3 * flops_per_token * batch * steps * seq_len + Solve: batch = total_flops / (3 * flops_per_token * steps * seq_len) + + Args: + model_config: Levanter model config with flops_per_token method. + vocab_size: Vocabulary size. + target_flops: Target total training FLOPs. + train_steps: Number of training steps. + seq_len: Sequence length. + + Returns: + Exact batch size (float) - caller decides how to round. 
+ """ + flops_per_token = model_config.flops_per_token(vocab_size, seq_len) + return target_flops / (3 * flops_per_token * train_steps * seq_len) + + +def solve_for_train_steps( + model_config: "LlamaConfig", + vocab_size: int, + target_flops: float, + batch_size: int, + seq_len: int, +) -> float: + """Solve for training steps needed to hit a target FLOP budget. + + Given: total_flops = 3 * flops_per_token * batch * steps * seq_len + Solve: steps = total_flops / (3 * flops_per_token * batch * seq_len) + + Args: + model_config: Levanter model config with flops_per_token method. + vocab_size: Vocabulary size. + target_flops: Target total training FLOPs. + batch_size: Training batch size. + seq_len: Sequence length. + + Returns: + Exact training steps (float) - caller decides how to round. + """ + flops_per_token = model_config.flops_per_token(vocab_size, seq_len) + return target_flops / (3 * flops_per_token * batch_size * seq_len) def estimate_memory_bytes( @@ -354,20 +401,32 @@ def estimate_memory_bytes( def pick_v5p_type( - param_count: int, - hidden: int, - layers: int, - batch: int, + model_config: "Qwen3Config", + vocab_size: int, + batch_size: int, seq_len: int, - vocab: int, ) -> str: """ Select the smallest TPU v5p slice that fits the model in float32. + Args: + model_config: Levanter model config with total_trainable_params method. + vocab_size: Vocabulary size. + batch_size: Training batch size. + seq_len: Sequence length. + Returns: - - TPU slice name, e.g., "v5p-8" or "v5p-32" + TPU slice name, e.g., "v5p-8" or "v5p-32" """ - need_bytes = estimate_memory_bytes(param_count, hidden, layers, batch, seq_len, vocab) + param_count = model_config.total_trainable_params(vocab_size) + need_bytes = estimate_memory_bytes( + param_count, + model_config.hidden_dim, + model_config.num_layers, + batch_size, + seq_len, + vocab_size, + ) chip_bytes = HBM_PER_CHIP_GIB * 1024**3 chips = math.ceil(need_bytes / chip_bytes) cores_req = chips * CORES_PER_CHIP @@ -386,148 +445,89 @@ def candidate_configs( ) -> Iterator[CandidateConfig]: """Yield candidate model configurations within the FLOP budget. + This function uses the recipe for all opinionated choices: + - Architecture formula (num_layers from hidden_size) + - Architecture ratios (mlp_ratio, hidden_head_ratio) + - Search bounds (min/max hidden_pow, step_size) + - Constraints (max_learning_rate, min_batch_size) + + The mechanics layer (solve_for_batch_size, solve_for_train_steps, compute_training_flops) + handles the pure FLOP math. 
+ Args: - cfg: IsoFlopSweepConfig with search parameters + cfg: IsoFlopSweepConfig with recipe and other search parameters budget: Target FLOP budget vocab_size: Vocabulary size for the tokenizer Yields: CandidateConfig objects for each valid configuration """ - if budget > 9e18: - step_size = 256 - else: - step_size = 128 - recipe = cfg.recipe - for hidden_size in range(2**cfg.min_hidden_pow, (2**cfg.max_hidden_pow) + 1, step_size): - hs_pow = math.log2(hidden_size) - intermediate_dim = hidden_size * recipe.mlp_ratio - num_layers = round(hidden_size / (cfg.base_hidden_layer_ratio + (hs_pow * 4) - cfg.min_hidden_pow)) - n_heads = max(1, hidden_size // recipe.hidden_head_ratio) - n_kv_heads = n_heads - - batch_exact = budget / compute_total_flops( - 1, - num_layers, - hidden_size, - intermediate_dim, - n_kv_heads, - n_heads, - cfg.steps_per_run, - cfg.seq_len, - vocab_size, - ) + # RECIPE: Get search parameters + step_size = recipe.get_step_size(budget) + min_hidden = 2**recipe.min_hidden_pow + max_hidden = 2**recipe.max_hidden_pow + + for hidden_size in range(min_hidden, max_hidden + 1, step_size): + # RECIPE: Build model config (makes all architecture decisions) + model_config = recipe.build_model_config(hidden_size, cfg.seq_len) + + # MECHANICS: Solve for batch size to hit budget with target steps + batch_exact = solve_for_batch_size(model_config, vocab_size, budget, cfg.steps_per_run, cfg.seq_len) batch_size = round_to_power_of_two(batch_exact) + + # RECIPE: Apply LR constraint lr = recipe.compute_learning_rate(batch_size, hidden_size) - while lr > 0.01: + while lr > recipe.max_learning_rate: batch_size //= 2 lr = recipe.compute_learning_rate(batch_size, hidden_size) - b2 = recipe.compute_beta2(batch_size) - if batch_size < 8: + # RECIPE: Apply min batch constraint + if batch_size < recipe.min_batch_size: continue - steps_exact = budget / compute_total_flops( - batch_size, - num_layers, - hidden_size, - intermediate_dim, - n_kv_heads, - n_heads, - 1, - cfg.seq_len, - vocab_size, - ) - train_steps = round(steps_exact) - - achieved_flops = compute_total_flops( - batch_size, - num_layers, - hidden_size, - intermediate_dim, - n_kv_heads, - n_heads, - train_steps, - cfg.seq_len, - vocab_size, - ) + # MECHANICS: Solve for steps to hit budget with chosen batch + train_steps = round(solve_for_train_steps(model_config, vocab_size, budget, batch_size, cfg.seq_len)) + # MECHANICS: Verify we hit the budget within tolerance + achieved_flops = compute_training_flops(model_config, vocab_size, batch_size, train_steps, cfg.seq_len) if abs(achieved_flops - budget) / budget > cfg.flop_tolerance: continue + # RECIPE: Compute optimizer hyperparameters + beta2 = recipe.compute_beta2(batch_size) tokens = batch_size * train_steps * cfg.seq_len yield CandidateConfig( hidden_size=hidden_size, - intermediate_dim=intermediate_dim, - num_layers=num_layers, - num_heads=n_heads, - num_kv_heads=n_kv_heads, + intermediate_dim=model_config.intermediate_dim, + num_layers=model_config.num_layers, + num_heads=model_config.num_heads, + num_kv_heads=model_config.num_kv_heads, batch_size=batch_size, train_steps=train_steps, learning_rate=lr, - beta2=b2, + beta2=beta2, tokens=tokens, flops_budget=budget, ) -def compute_transformer_params( - hidden_dim: int, - intermediate_dim: int, - num_layers: int, - vocab_size: int, - num_kv_heads: int | None = None, - num_heads: int | None = None, -) -> int: - """Compute parameter count for a transformer model. 
- - This is a standard approximation for LLaMA-style models with: - - Embedding: vocab_size * hidden_dim - - Per layer: 4 * hidden_dim^2 (attention) + 3 * hidden_dim * intermediate_dim (MLP with GLU) - - Output: vocab_size * hidden_dim (shared with embedding in some models) - """ - # Embedding parameters - embed_params = vocab_size * hidden_dim - - # Attention parameters per layer: Q, K, V, O projections - # For GQA: Q = hidden * hidden, K = hidden * kv_dim, V = hidden * kv_dim, O = hidden * hidden - if num_kv_heads is not None and num_heads is not None: - head_dim = hidden_dim // num_heads - kv_dim = num_kv_heads * head_dim - attn_params_per_layer = ( - hidden_dim * hidden_dim # Q - + hidden_dim * kv_dim # K - + hidden_dim * kv_dim # V - + hidden_dim * hidden_dim # O - ) - else: - # Standard MHA: 4 * hidden^2 - attn_params_per_layer = 4 * hidden_dim * hidden_dim - - # MLP parameters per layer (GLU: gate, up, down) - mlp_params_per_layer = 3 * hidden_dim * intermediate_dim - - # Layer norm parameters (2 per layer + 1 final) - ln_params = (2 * num_layers + 1) * hidden_dim - - # Total - layer_params = (attn_params_per_layer + mlp_params_per_layer) * num_layers - total = embed_params + layer_params + ln_params - - return total - - # ---------------- Shared Model/Optimizer Builders ---------------- +# Note: For parameter counts, use model_config.total_trainable_params(vocab_size) +# which is defined on Levanter model configs (LlamaConfig, Qwen3Config, etc.) -def build_model_config(candidate: CandidateConfig, seq_len: int = SEQ_LEN) -> Qwen3Config: - """Build a Qwen3Config from a CandidateConfig. +def candidate_to_model_config(candidate: CandidateConfig, seq_len: int = SEQ_LEN) -> Qwen3Config: + """Convert a CandidateConfig to a Qwen3Config for training. - This is the shared builder used by both generate_isoflop_train_args() and - scaling_ladder's run_scaling_ladder_rung() to ensure consistent model configs. + This is used after candidate search to convert the selected candidate + into a model config for actual training. The architecture parameters + come directly from the candidate (which were determined by the recipe + during search). + + Note: For creating model configs during search, use recipe.build_model_config(hidden_size). 
""" return Qwen3Config( max_seq_len=seq_len, @@ -656,20 +656,10 @@ def generate_isoflop_train_args( for budget in sweep_config.budgets: for candidate in candidate_configs(sweep_config, budget, vocab_size): # Build model config using shared builder - model_cfg = build_model_config(candidate, sweep_config.seq_len) - - # Compute parameter count for TPU selection - param_count = model_cfg.total_trainable_params(vocab_size) - - # Pick TPU type - tpu_type = pick_v5p_type( - param_count=param_count, - hidden=candidate.hidden_size, - layers=candidate.num_layers, - batch=candidate.batch_size, - seq_len=sweep_config.seq_len, - vocab=vocab_size, - ) + model_cfg = candidate_to_model_config(candidate, sweep_config.seq_len) + + # Pick TPU type based on model config + tpu_type = pick_v5p_type(model_cfg, vocab_size, candidate.batch_size, sweep_config.seq_len) # Build optimizer config with candidate-specific LR and beta2 optimizer_cfg = replace( diff --git a/lib/marin/src/marin/scaling_laws/recipe.py b/lib/marin/src/marin/scaling_laws/recipe.py index 1f0dfe8fd1..d225d500c3 100644 --- a/lib/marin/src/marin/scaling_laws/recipe.py +++ b/lib/marin/src/marin/scaling_laws/recipe.py @@ -17,45 +17,59 @@ A recipe makes "opinionated defaults" explicit and named, so users consciously choose which set of hyperparameters to use rather than getting hidden defaults. +The recipe controls: +- Architecture formula (how to compute num_layers from hidden_size) +- Architecture ratios (MLP width, head size) +- Learning rate and optimizer hyperparameters +- Search bounds and constraints for isoflop sweeps + Usage: from marin.scaling_laws.recipe import MARIN_2025_RECIPE, ScalingRecipe # Use the default recipe recipe = MARIN_2025_RECIPE + model_config = recipe.build_model_config(hidden_size=1024, seq_len=4096) lr = recipe.compute_learning_rate(batch_size=256, hidden_dim=1024) beta2 = recipe.compute_beta2(batch_size=256) - # Or create a custom recipe + # Or create a custom recipe with different architecture formula my_recipe = ScalingRecipe( name="my-experiment", lr_constant=0.25, - weight_decay=0.05, + base_hidden_layer_ratio=48, # shallower models ) """ import math from dataclasses import dataclass +from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig +from levanter.models.qwen import Qwen3Config + @dataclass(frozen=True) class ScalingRecipe: """A named set of hyperparameters for scaling law experiments. The recipe controls: + - Architecture formula (num_layers from hidden_size) + - Architecture ratios (MLP width, head size) - Learning rate scaling formula - Beta2 scaling formula (for Adam) - Optimizer hyperparameters (weight decay, warmup, etc.) 
- - Architecture ratios (MLP width, head size) + - Search bounds and constraints for isoflop sweeps """ name: str """Name identifying this recipe (e.g., 'marin-2025').""" - # Learning rate scaling: lr = lr_constant * sqrt(batch_size) / hidden_dim + # --- Learning rate scaling --- + # lr = lr_constant * sqrt(batch_size) / hidden_dim lr_constant: float = 0.33 """Constant for learning rate calculation.""" - # Beta2 scaling for Adam: beta2 = beta2_base ** (batch_size / beta2_batch_divisor) + # --- Beta2 scaling for Adam --- + # beta2 = beta2_base ** (batch_size / beta2_batch_divisor) # Reference: https://arxiv.org/pdf/2507.07101 beta2_base: float = 0.98 """Base for beta2 exponential scaling.""" @@ -63,7 +77,7 @@ class ScalingRecipe: beta2_batch_divisor: float = 128 """Divisor for beta2 batch size scaling.""" - # Optimizer hyperparameters + # --- Optimizer hyperparameters --- weight_decay: float = 0.1 min_lr_ratio: float = 0.0 warmup: float = 0.1 @@ -73,13 +87,54 @@ class ScalingRecipe: lr_schedule: str = "linear" decay: float = 0.2 - # Architecture ratios + # --- Architecture ratios --- mlp_ratio: int = 4 """MLP intermediate_dim = hidden_dim * mlp_ratio.""" hidden_head_ratio: int = 128 """num_heads = hidden_dim / hidden_head_ratio.""" + # --- Architecture formula for depth-to-width scaling --- + # num_layers = round( + # hidden_size + # / ( + # base_hidden_layer_ratio + # + (log2(hidden_size) * layer_scaling_factor) + # - layer_formula_offset + # ) + # ) + base_hidden_layer_ratio: int = 64 + """Base divisor for depth-width formula.""" + + layer_scaling_factor: float = 4.0 + """Multiplier for log2(hidden_size) in depth formula.""" + + layer_formula_offset: int = 9 + """Offset (typically min_hidden_pow) in depth formula.""" + + # --- Constraints --- + max_learning_rate: float = 0.01 + """Maximum allowed learning rate (configs with higher LR are rejected).""" + + min_batch_size: int = 8 + """Minimum allowed batch size (configs with smaller batch are rejected).""" + + # --- Search bounds for isoflop sweeps --- + min_hidden_pow: int = 9 + """Minimum hidden_size as power of 2 (2^9 = 512).""" + + max_hidden_pow: int = 12 + """Maximum hidden_size as power of 2 (2^12 = 4096).""" + + small_budget_step_size: int = 128 + """Step size for hidden_size search at smaller budgets.""" + + large_budget_step_size: int = 256 + """Step size for hidden_size search at larger budgets.""" + + budget_step_threshold: float = 9e18 + """Budget threshold for switching step sizes.""" + def compute_learning_rate(self, batch_size: int, hidden_dim: int) -> float: """Compute learning rate from batch size and hidden dim.""" return (self.lr_constant * math.sqrt(batch_size)) / hidden_dim @@ -88,6 +143,50 @@ def compute_beta2(self, batch_size: int) -> float: """Compute beta2 from batch size.""" return self.beta2_base ** (batch_size / self.beta2_batch_divisor) + def compute_num_layers(self, hidden_size: int) -> int: + """Compute number of layers from hidden size using the depth-width formula. + + This is an opinionated formula for balancing model depth and width. 
+ """ + hs_pow = math.log2(hidden_size) + return round( + hidden_size + / (self.base_hidden_layer_ratio + (hs_pow * self.layer_scaling_factor) - self.layer_formula_offset) + ) + + def get_step_size(self, budget: float) -> int: + """Get hidden_size search step size based on budget.""" + if budget > self.budget_step_threshold: + return self.large_budget_step_size + return self.small_budget_step_size + + def build_model_config(self, hidden_size: int, seq_len: int = 4096) -> Qwen3Config: + """Build a model config from hidden_size using this recipe's architecture formula. + + This is the key interface: the recipe makes all architecture decisions + and returns a fully-specified model config. + + Args: + hidden_size: Model hidden dimension. + seq_len: Maximum sequence length. + + Returns: + A Qwen3Config with architecture determined by this recipe. + """ + num_layers = self.compute_num_layers(hidden_size) + intermediate_dim = hidden_size * self.mlp_ratio + n_heads = max(1, hidden_size // self.hidden_head_ratio) + + return Qwen3Config( + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + num_heads=n_heads, + num_kv_heads=n_heads, + max_seq_len=seq_len, + rope=Llama3RotaryEmbeddingsConfig(), + ) + # Named recipes MARIN_2025_RECIPE = ScalingRecipe(name="marin-2025") diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index f1ea8fd865..77263e1455 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -61,8 +61,8 @@ from marin.scaling_laws.isoflop_analysis import ( IsoFlopSweepConfig, ScalingFit, - build_model_config, build_optimizer_config, + candidate_to_model_config, isoflop_analysis_step, pick_v5p_type, predict_optimal_config, @@ -185,17 +185,9 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: f" learning_rate={candidate.learning_rate:.6f}, tokens={candidate.tokens:.2e}" ) - model_cfg = build_model_config(candidate, config.seq_len) + model_cfg = candidate_to_model_config(candidate, config.seq_len) - param_count = model_cfg.total_trainable_params(vocab_size) - tpu_type = pick_v5p_type( - param_count, - candidate.hidden_size, - candidate.num_layers, - candidate.batch_size, - config.seq_len, - vocab_size, - ) + tpu_type = pick_v5p_type(model_cfg, vocab_size, candidate.batch_size, config.seq_len) optimizer_cfg = build_optimizer_config(candidate, config.recipe) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 87dbb6eae0..a4ea8debe4 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -22,18 +22,24 @@ import pandas as pd import pytest +from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig +from levanter.models.qwen import Qwen3Config + from marin.scaling_laws.isoflop_analysis import ( MARIN_TOKENIZER_VOCAB_SIZE, IsoFlopSweepConfig, IsoFlopTrainArgs, candidate_configs, - compute_total_flops, + candidate_to_model_config, + compute_training_flops, fit_scaling_laws, generate_isoflop_train_args, parse_isoflop_run_name, robust_quad_logx, round_flops_to_bucket, round_to_power_of_two, + solve_for_batch_size, + solve_for_train_steps, transform_metrics_for_isoflop, ) @@ -90,43 +96,106 @@ def test_round_flops_to_bucket(value, expected): # --- FLOP computation tests --- -def test_compute_total_flops_linear_in_batch_and_steps(): +def test_compute_training_flops_linear_in_batch_and_steps(): """Test that FLOPs scale linearly with batch size and steps.""" - base_flops = 
compute_total_flops( - batch=32, - num_layers=12, - hidden=512, - intermediate=2048, - num_kv_heads=8, + # Build a model config for testing + model_config = Qwen3Config( + max_seq_len=4096, + hidden_dim=512, + intermediate_dim=2048, num_heads=8, - steps=1000, - seq_len=4096, - vocab_size=128256, - ) - double_batch_flops = compute_total_flops( - batch=64, - num_layers=12, - hidden=512, - intermediate=2048, num_kv_heads=8, - num_heads=8, - steps=1000, - seq_len=4096, - vocab_size=128256, + num_layers=12, + rope=Llama3RotaryEmbeddingsConfig(), ) - double_steps_flops = compute_total_flops( - batch=32, + vocab_size = 128256 + seq_len = 4096 + + base_flops = compute_training_flops(model_config, vocab_size, 32, 1000, seq_len) + double_batch_flops = compute_training_flops(model_config, vocab_size, 64, 1000, seq_len) + double_steps_flops = compute_training_flops(model_config, vocab_size, 32, 2000, seq_len) + + assert abs(double_batch_flops - 2 * base_flops) / base_flops < 0.01 + assert abs(double_steps_flops - 2 * base_flops) / base_flops < 0.01 + + +def test_solve_for_batch_size_inverts_flop_calculation(): + """Test that solve_for_batch_size correctly inverts compute_training_flops.""" + model_config = Qwen3Config( + max_seq_len=4096, + hidden_dim=768, + intermediate_dim=3072, + num_heads=12, + num_kv_heads=12, num_layers=12, - hidden=512, - intermediate=2048, + rope=Llama3RotaryEmbeddingsConfig(), + ) + vocab_size = 128256 + seq_len = 4096 + train_steps = 10000 + original_batch_size = 64 + + # Compute FLOPs for known batch size + target_flops = compute_training_flops(model_config, vocab_size, original_batch_size, train_steps, seq_len) + + # Solve for batch size given those FLOPs + recovered_batch = solve_for_batch_size(model_config, vocab_size, target_flops, train_steps, seq_len) + + # Should recover original batch size (exact float) + assert abs(recovered_batch - original_batch_size) < 0.01 + + +def test_solve_for_train_steps_inverts_flop_calculation(): + """Test that solve_for_train_steps correctly inverts compute_training_flops.""" + model_config = Qwen3Config( + max_seq_len=4096, + hidden_dim=1024, + intermediate_dim=4096, + num_heads=8, num_kv_heads=8, + num_layers=16, + rope=Llama3RotaryEmbeddingsConfig(), + ) + vocab_size = 128256 + seq_len = 4096 + batch_size = 32 + original_steps = 50000 + + # Compute FLOPs for known steps + target_flops = compute_training_flops(model_config, vocab_size, batch_size, original_steps, seq_len) + + # Solve for steps given those FLOPs + recovered_steps = solve_for_train_steps(model_config, vocab_size, target_flops, batch_size, seq_len) + + # Should recover original steps (exact float) + assert abs(recovered_steps - original_steps) < 0.01 + + +def test_solvers_consistent_with_each_other(): + """Test that solving for batch and then steps gives consistent results.""" + model_config = Qwen3Config( + max_seq_len=4096, + hidden_dim=512, + intermediate_dim=2048, num_heads=8, - steps=2000, - seq_len=4096, - vocab_size=128256, + num_kv_heads=8, + num_layers=8, + rope=Llama3RotaryEmbeddingsConfig(), ) - assert abs(double_batch_flops - 2 * base_flops) / base_flops < 0.01 - assert abs(double_steps_flops - 2 * base_flops) / base_flops < 0.01 + vocab_size = 128256 + seq_len = 4096 + target_flops = 1e19 + + # Pick arbitrary steps, solve for batch + steps = 20000 + batch = solve_for_batch_size(model_config, vocab_size, target_flops, steps, seq_len) + + # Now with that batch, solve for steps - should get back original + recovered_steps = solve_for_train_steps(model_config, 
vocab_size, target_flops, round(batch), seq_len) + + # Allow small error from rounding batch to int + relative_error = abs(recovered_steps - steps) / steps + assert relative_error < 0.01 # --- Run name parsing tests --- @@ -150,16 +219,14 @@ def test_candidate_configs_within_tolerance(): cfg = IsoFlopSweepConfig(flop_tolerance=0.01) budget = 1e19 for candidate in candidate_configs(cfg, budget, MARIN_TOKENIZER_VOCAB_SIZE): - achieved = compute_total_flops( + # Build model config from candidate to verify FLOPs + model_config = candidate_to_model_config(candidate, cfg.seq_len) + achieved = compute_training_flops( + model_config, + MARIN_TOKENIZER_VOCAB_SIZE, candidate.batch_size, - candidate.num_layers, - candidate.hidden_size, - candidate.intermediate_dim, - candidate.num_kv_heads, - candidate.num_heads, candidate.train_steps, cfg.seq_len, - MARIN_TOKENIZER_VOCAB_SIZE, ) relative_error = abs(achieved - budget) / budget assert relative_error <= cfg.flop_tolerance @@ -185,7 +252,7 @@ def test_robust_quad_logx_fits_quadratic(): # --- Snapshot test for config generation --- # Snapshot of expected output for generate_isoflop_train_args with budget=3e18 training FLOPs. -# Note: compute_total_flops includes the 3x multiplier for training (forward + backward pass), +# Note: compute_training_flops includes the 3x multiplier for training (forward + backward pass), # matching how FLOPs are tracked in WandB via Levanter's log_performance_stats. EXPECTED_ISOFLOP_CONFIGS_3E18 = [ { From acc0ff4d44d8ef4c00c86338c328d2c1ef35727e Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 6 Jan 2026 20:44:35 -0800 Subject: [PATCH 43/79] Try to fix inversion --- experiments/isoflop_sweep.py | 25 +++- lib/marin/src/marin/scaling_laws/__init__.py | 8 +- .../marin/scaling_laws/isoflop_analysis.py | 139 ++++++++++++------ .../src/marin/scaling_laws/scaling_ladder.py | 46 +++++- tests/test_scaling_laws.py | 4 +- 5 files changed, 165 insertions(+), 57 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 67c32d56d4..2e783b0c70 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -31,6 +31,8 @@ from dataclasses import replace from levanter.data.text import LMMixtureDatasetConfig +from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig +from levanter.models.qwen import Qwen3Config from experiments.evals.evals import default_eval from experiments.evals.task_configs import EvalTaskConfig @@ -51,6 +53,24 @@ from marin.scaling_laws.recipe import MARIN_2025_RECIPE, ScalingRecipe +def build_qwen3_from_candidate(candidate: CandidateConfig, seq_len: int = 4096) -> Qwen3Config: + """Build a Qwen3Config from a CandidateConfig. + + This is the experiment-level helper for constructing model configs. + Different experiments can use different model types (LlamaConfig, etc.) + by implementing their own builder function. 
+ """ + return Qwen3Config( + hidden_dim=candidate.hidden_size, + intermediate_dim=candidate.intermediate_dim, + num_layers=candidate.num_layers, + num_heads=candidate.num_heads, + num_kv_heads=candidate.num_kv_heads, + max_seq_len=seq_len, + rope=Llama3RotaryEmbeddingsConfig(), + ) + + def create_isoflop_sweep_steps( tokenized: InputName | str | LMMixtureDatasetConfig, experiment_name: str, @@ -108,6 +128,9 @@ def create_isoflop_sweep_steps( # Create ExecutorSteps for each candidate configuration for args in train_args_list: + # Build model config from candidate (experiment controls model type) + model_config = build_qwen3_from_candidate(args.candidate, sweep_config.seq_len) + train_cfg = replace( base_train_config, train_batch_size=args.candidate.batch_size, @@ -121,7 +144,7 @@ def create_isoflop_sweep_steps( train_step = default_train( name=args.run_name, tokenized=tokenized, - model_config=args.model_config, + model_config=model_config, train_config=train_cfg, eval_harness_tasks=[], tags=args.tags, diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 69643ca65b..116b073ce0 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -27,8 +27,8 @@ UploadPlotsToWandbConfig, build_optimizer_config, candidate_configs, - candidate_to_model_config, compute_training_flops, + compute_transformer_params, fit_scaling_laws, generate_isoflop_train_args, isoflop_analysis_step, @@ -47,8 +47,10 @@ ScalingRecipe, ) from marin.scaling_laws.scaling_ladder import ( + ModelBuilder, ScalingLadderRungConfig, ScalingLadderSuite, + default_model_builder, run_scaling_ladder_rung, scaling_ladder_rung_step, scaling_ladder_suite, @@ -71,6 +73,7 @@ "IsoFlopSweepConfig", "IsoFlopTrainArgs", "MinimaRecord", + "ModelBuilder", "QuadraticFitCoeffs", "ScalingFit", "ScalingLadderRungConfig", @@ -79,10 +82,11 @@ "UploadPlotsToWandbConfig", "build_optimizer_config", "candidate_configs", - "candidate_to_model_config", "compute_training_flops", + "compute_transformer_params", "create_isoflop_plot", "create_scaling_plot", + "default_model_builder", "fit_scaling_laws", "generate_isoflop_train_args", "isoflop_analysis_step", diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 9f52983494..c46435897e 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -62,9 +62,7 @@ import pandas as pd from jaxopt import ScipyMinimize -from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig from levanter.models.llama import LlamaConfig -from levanter.models.qwen import Qwen3Config from levanter.optim.cautious import CautiousConfig from levanter.optim.config import OptimizerConfig @@ -193,17 +191,27 @@ class CandidateConfig: class IsoFlopTrainArgs: """Arguments needed to set up an isoflop training run. - This dataclass contains all the information needed to call default_train() - for an isoflop sweep run. The caller is responsible for constructing the - experiment-specific SimpleTrainConfig from these arguments. + This dataclass contains the parameters needed for training. The caller is + responsible for constructing the model config from candidate parameters, + allowing flexibility in model type (Qwen3Config, LlamaConfig, etc.). 
+ + Example: + >>> args = generate_isoflop_train_args(config, "my-exp", vocab_size)[0] + >>> # Caller constructs the model config + >>> model_config = Qwen3Config( + ... hidden_dim=args.candidate.hidden_size, + ... intermediate_dim=args.candidate.intermediate_dim, + ... num_layers=args.candidate.num_layers, + ... num_heads=args.candidate.num_heads, + ... num_kv_heads=args.candidate.num_kv_heads, + ... max_seq_len=4096, + ... rope=Llama3RotaryEmbeddingsConfig(), + ... ) """ candidate: CandidateConfig """The candidate configuration with model/training hyperparameters.""" - model_config: Qwen3Config - """Levanter model configuration ready to use.""" - optimizer_config: OptimizerConfig """Levanter optimizer configuration with learning_rate and beta2 set.""" @@ -367,6 +375,56 @@ def solve_for_train_steps( return target_flops / (3 * flops_per_token * batch_size * seq_len) +def compute_transformer_params( + hidden_dim: int, + intermediate_dim: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + vocab_size: int, + tie_embeddings: bool = False, +) -> int: + """Compute parameter count for a standard transformer (Llama/Qwen architecture). + + This matches the formula used in Levanter's LlamaConfig.total_trainable_params(), + allowing parameter estimation without constructing a model config. + + Args: + hidden_dim: Model hidden dimension. + intermediate_dim: MLP intermediate dimension. + num_layers: Number of transformer layers. + num_heads: Number of attention heads. + num_kv_heads: Number of key-value heads (for GQA). + vocab_size: Vocabulary size. + tie_embeddings: Whether embeddings are tied (default False). + + Returns: + Total parameter count. + """ + token_embedding = vocab_size * hidden_dim + head_size = hidden_dim // num_heads + + # Attention: Q, K, V projections + output projection + q_proj = hidden_dim * head_size * num_heads + kv_proj = 2 * hidden_dim * head_size * num_kv_heads + o_proj = head_size * num_heads * hidden_dim + attn = q_proj + kv_proj + o_proj + + # MLP: gate, up, down projections (SwiGLU uses 3 matrices) + mlp = 3 * hidden_dim * intermediate_dim + + # Per-layer: attention + mlp + 2 RMSNorm + transformer_layer = attn + mlp + 2 * hidden_dim + + # Full transformer: layers + final RMSNorm + transformer = num_layers * transformer_layer + hidden_dim + + # LM head (separate unless tied) + lm_head = 0 if tie_embeddings else token_embedding + + return transformer + token_embedding + lm_head + + def estimate_memory_bytes( param_count: int, hidden_dim: int, @@ -401,29 +459,34 @@ def estimate_memory_bytes( def pick_v5p_type( - model_config: "Qwen3Config", + candidate: CandidateConfig, vocab_size: int, - batch_size: int, seq_len: int, ) -> str: """ Select the smallest TPU v5p slice that fits the model in float32. Args: - model_config: Levanter model config with total_trainable_params method. + candidate: CandidateConfig with model architecture parameters. vocab_size: Vocabulary size. - batch_size: Training batch size. seq_len: Sequence length. 
Returns: TPU slice name, e.g., "v5p-8" or "v5p-32" """ - param_count = model_config.total_trainable_params(vocab_size) + param_count = compute_transformer_params( + hidden_dim=candidate.hidden_size, + intermediate_dim=candidate.intermediate_dim, + num_layers=candidate.num_layers, + num_heads=candidate.num_heads, + num_kv_heads=candidate.num_kv_heads, + vocab_size=vocab_size, + ) need_bytes = estimate_memory_bytes( param_count, - model_config.hidden_dim, - model_config.num_layers, - batch_size, + candidate.hidden_size, + candidate.num_layers, + candidate.batch_size, seq_len, vocab_size, ) @@ -514,30 +577,7 @@ def candidate_configs( ) -# ---------------- Shared Model/Optimizer Builders ---------------- -# Note: For parameter counts, use model_config.total_trainable_params(vocab_size) -# which is defined on Levanter model configs (LlamaConfig, Qwen3Config, etc.) - - -def candidate_to_model_config(candidate: CandidateConfig, seq_len: int = SEQ_LEN) -> Qwen3Config: - """Convert a CandidateConfig to a Qwen3Config for training. - - This is used after candidate search to convert the selected candidate - into a model config for actual training. The architecture parameters - come directly from the candidate (which were determined by the recipe - during search). - - Note: For creating model configs during search, use recipe.build_model_config(hidden_size). - """ - return Qwen3Config( - max_seq_len=seq_len, - hidden_dim=candidate.hidden_size, - intermediate_dim=candidate.intermediate_dim, - num_heads=candidate.num_heads, - num_kv_heads=candidate.num_kv_heads, - num_layers=candidate.num_layers, - rope=Llama3RotaryEmbeddingsConfig(), - ) +# ---------------- Shared Builders ---------------- def build_optimizer_config( @@ -629,11 +669,18 @@ def generate_isoflop_train_args( Example: >>> from marin.scaling_laws import IsoFlopSweepConfig, generate_isoflop_train_args + >>> from levanter.models.qwen import Qwen3Config >>> config = IsoFlopSweepConfig(budgets=(1e18, 1e19)) >>> train_args = generate_isoflop_train_args(config, "my-experiment", vocab_size=128256) >>> for args in train_args: - ... # Use args.model_config, args.optimizer_config, etc. with default_train() - ... pass + ... # Caller constructs the model config from candidate parameters + ... model_config = Qwen3Config( + ... hidden_dim=args.candidate.hidden_size, + ... intermediate_dim=args.candidate.intermediate_dim, + ... num_layers=args.candidate.num_layers, + ... # ... etc + ... ) + ... 
# Then use model_config with default_train() """ recipe = sweep_config.recipe if base_optimizer_config is None: @@ -655,11 +702,8 @@ def generate_isoflop_train_args( for budget in sweep_config.budgets: for candidate in candidate_configs(sweep_config, budget, vocab_size): - # Build model config using shared builder - model_cfg = candidate_to_model_config(candidate, sweep_config.seq_len) - - # Pick TPU type based on model config - tpu_type = pick_v5p_type(model_cfg, vocab_size, candidate.batch_size, sweep_config.seq_len) + # Pick TPU type based on candidate parameters + tpu_type = pick_v5p_type(candidate, vocab_size, sweep_config.seq_len) # Build optimizer config with candidate-specific LR and beta2 optimizer_cfg = replace( @@ -689,7 +733,6 @@ def generate_isoflop_train_args( results.append( IsoFlopTrainArgs( candidate=candidate, - model_config=model_cfg, optimizer_config=optimizer_cfg, tpu_type=tpu_type, run_name=run_name, diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 77263e1455..20e6199ca7 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -39,9 +39,10 @@ import json import logging import os -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass from datetime import timedelta +from typing import Any import fsspec import jmp @@ -49,7 +50,10 @@ from haliax.partitioning import ResourceAxis from levanter.checkpoint import CheckpointerConfig from levanter.data.text import LMMixtureDatasetConfig +from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig from levanter.main.train_lm import TrainLmConfig +from levanter.models.lm_model import LmConfig +from levanter.models.qwen import Qwen3Config from levanter.tracker.wandb import WandbConfig from levanter.trainer import TrainerConfig from levanter.utils.mesh import MeshConfig @@ -59,10 +63,10 @@ from marin.processing.tokenize.data_configs import add_validation_sets_to_mixture, lm_data_config from marin.processing.tokenize.tokenize import TokenizeConfig from marin.scaling_laws.isoflop_analysis import ( + CandidateConfig, IsoFlopSweepConfig, ScalingFit, build_optimizer_config, - candidate_to_model_config, isoflop_analysis_step, pick_v5p_type, predict_optimal_config, @@ -75,6 +79,27 @@ # Type alias for tokenizer steps TokenizerStep = ExecutorStep[TokenizeConfig] +# Type alias for model builder callbacks +# Takes (candidate, seq_len) and returns a model config +ModelBuilder = Callable[[CandidateConfig, int], LmConfig] + + +def default_model_builder(candidate: CandidateConfig, seq_len: int) -> Qwen3Config: + """Default model builder that creates Qwen3Config. + + This is provided as a convenience for the common case. Users can pass + their own model_builder function to use different model types. + """ + return Qwen3Config( + hidden_dim=candidate.hidden_size, + intermediate_dim=candidate.intermediate_dim, + num_layers=candidate.num_layers, + num_heads=candidate.num_heads, + num_kv_heads=candidate.num_kv_heads, + max_seq_len=seq_len, + rope=Llama3RotaryEmbeddingsConfig(), + ) + def _prepare_data_config( tokenized: InputName | str | LMMixtureDatasetConfig, @@ -133,6 +158,9 @@ class ScalingLadderRungConfig: output_path: str """Where to write training outputs.""" + model_builder: ModelBuilder | None = None + """Function to build model config from CandidateConfig. 
If None, uses default_model_builder (Qwen3).""" + recipe: ScalingRecipe = MARIN_2025_RECIPE """Scaling recipe with hyperparameters.""" @@ -185,9 +213,11 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: f" learning_rate={candidate.learning_rate:.6f}, tokens={candidate.tokens:.2e}" ) - model_cfg = candidate_to_model_config(candidate, config.seq_len) + # Use provided model builder or default to Qwen3 + model_builder = config.model_builder or default_model_builder + model_cfg = model_builder(candidate, config.seq_len) - tpu_type = pick_v5p_type(model_cfg, vocab_size, candidate.batch_size, config.seq_len) + tpu_type = pick_v5p_type(candidate, vocab_size, config.seq_len) optimizer_cfg = build_optimizer_config(candidate, config.recipe) @@ -242,6 +272,7 @@ def scaling_ladder_rung_step( target_budget: float, label: str, tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, + model_builder: ModelBuilder | None = None, recipe: ScalingRecipe = MARIN_2025_RECIPE, tokenizer: str = "stanford-crfm/marin-tokenizer", seq_len: int = 4096, @@ -260,6 +291,8 @@ def scaling_ladder_rung_step( label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') tokenized: Tokenized dataset to train on. Can be an ExecutorStep, InputName, or LMMixtureDatasetConfig. + model_builder: Function to build model config from CandidateConfig. + If None, uses default_model_builder (Qwen3Config). recipe: ScalingRecipe with hyperparameters tokenizer: Tokenizer to use seq_len: Sequence length for training @@ -284,6 +317,7 @@ def scaling_ladder_rung_step( label=label, tokenized=resolved_tokenized, output_path=output_path, + model_builder=model_builder, recipe=recipe, tokenizer=tokenizer, seq_len=seq_len, @@ -332,6 +366,7 @@ def scaling_ladder_suite( target_budgets: Sequence[float], label: str, tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, + model_builder: ModelBuilder | None = None, recipe: ScalingRecipe = MARIN_2025_RECIPE, tokenizer: str = "stanford-crfm/marin-tokenizer", seq_len: int = 4096, @@ -355,6 +390,8 @@ def scaling_ladder_suite( label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') tokenized: Tokenized dataset for optimal training runs. Can be an ExecutorStep, InputName, or LMMixtureDatasetConfig. + model_builder: Function to build model config from CandidateConfig. + If None, uses default_model_builder (Qwen3Config). 
recipe: ScalingRecipe with hyperparameters tokenizer: Tokenizer to use seq_len: Sequence length for training @@ -391,6 +428,7 @@ def scaling_ladder_suite( target_budget=budget, label=label, tokenized=tokenized, + model_builder=model_builder, recipe=recipe, tokenizer=tokenizer, seq_len=seq_len, diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index a4ea8debe4..7f9324c7d7 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -25,12 +25,12 @@ from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig from levanter.models.qwen import Qwen3Config +from marin.scaling_laws import default_model_builder from marin.scaling_laws.isoflop_analysis import ( MARIN_TOKENIZER_VOCAB_SIZE, IsoFlopSweepConfig, IsoFlopTrainArgs, candidate_configs, - candidate_to_model_config, compute_training_flops, fit_scaling_laws, generate_isoflop_train_args, @@ -220,7 +220,7 @@ def test_candidate_configs_within_tolerance(): budget = 1e19 for candidate in candidate_configs(cfg, budget, MARIN_TOKENIZER_VOCAB_SIZE): # Build model config from candidate to verify FLOPs - model_config = candidate_to_model_config(candidate, cfg.seq_len) + model_config = default_model_builder(candidate, cfg.seq_len) achieved = compute_training_flops( model_config, MARIN_TOKENIZER_VOCAB_SIZE, From a0ecf86f1f66a4ee7b4e15101181fdccbc8e023a Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 6 Jan 2026 20:44:57 -0800 Subject: [PATCH 44/79] Lint --- lib/marin/src/marin/scaling_laws/scaling_ladder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 20e6199ca7..07461061ed 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -42,7 +42,6 @@ from collections.abc import Callable, Sequence from dataclasses import dataclass from datetime import timedelta -from typing import Any import fsspec import jmp From 90c2b3810fe63950850cdf80bc151184319fd94b Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 6 Jan 2026 21:11:44 -0800 Subject: [PATCH 45/79] Move the Optimizer stuff outside of the unopinionated section --- lib/marin/src/marin/scaling_laws/__init__.py | 2 -- .../marin/scaling_laws/isoflop_analysis.py | 31 ------------------- lib/marin/src/marin/scaling_laws/recipe.py | 28 +++++++++++++++++ .../src/marin/scaling_laws/scaling_ladder.py | 3 +- 4 files changed, 29 insertions(+), 35 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 116b073ce0..bb94bf3553 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -25,7 +25,6 @@ QuadraticFitCoeffs, ScalingFit, UploadPlotsToWandbConfig, - build_optimizer_config, candidate_configs, compute_training_flops, compute_transformer_params, @@ -80,7 +79,6 @@ "ScalingLadderSuite", "ScalingRecipe", "UploadPlotsToWandbConfig", - "build_optimizer_config", "candidate_configs", "compute_training_flops", "compute_transformer_params", diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index c46435897e..337019bed3 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -577,37 +577,6 @@ def candidate_configs( ) -# ---------------- Shared Builders ---------------- - - -def build_optimizer_config( - candidate: CandidateConfig, - recipe: ScalingRecipe 
= MARIN_2025_RECIPE, -) -> CautiousConfig: - """Build optimizer config from a CandidateConfig and ScalingRecipe. - - This is the shared builder used by both generate_isoflop_train_args() and - scaling_ladder's run_scaling_ladder_rung() to ensure consistent optimizer configs. - - Args: - candidate: CandidateConfig with learning_rate and beta2. - recipe: ScalingRecipe with optimizer hyperparameters. - """ - return CautiousConfig( - learning_rate=candidate.learning_rate, - weight_decay=recipe.weight_decay, - min_lr_ratio=recipe.min_lr_ratio, - warmup=recipe.warmup, - beta1=recipe.beta1, - beta2=candidate.beta2, - epsilon=recipe.epsilon, - max_grad_norm=recipe.max_grad_norm, - adamc_weight_decay=True, - lr_schedule=recipe.lr_schedule, - decay=recipe.decay, - ) - - def _minima_to_candidates( minima_records: list[MinimaRecord], recipe: ScalingRecipe = MARIN_2025_RECIPE, diff --git a/lib/marin/src/marin/scaling_laws/recipe.py b/lib/marin/src/marin/scaling_laws/recipe.py index d225d500c3..560bcbe685 100644 --- a/lib/marin/src/marin/scaling_laws/recipe.py +++ b/lib/marin/src/marin/scaling_laws/recipe.py @@ -45,6 +45,7 @@ from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig from levanter.models.qwen import Qwen3Config +from levanter.optim.cautious import CautiousConfig @dataclass(frozen=True) @@ -187,6 +188,33 @@ def build_model_config(self, hidden_size: int, seq_len: int = 4096) -> Qwen3Conf rope=Llama3RotaryEmbeddingsConfig(), ) + def build_optimizer_config(self, learning_rate: float, beta2: float) -> CautiousConfig: + """Build optimizer config using this recipe's hyperparameters. + + This centralizes all optimizer configuration in the recipe, ensuring + consistent hyperparameters across isoflop sweeps and optimal training runs. + + Args: + learning_rate: Learning rate (typically from CandidateConfig). + beta2: Adam beta2 (typically from CandidateConfig). + + Returns: + A CautiousConfig with optimizer settings from this recipe. 
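+
+        Example (illustrative values; in a sweep, learning_rate and beta2 come
+        from a CandidateConfig rather than being hard-coded):
+
+            >>> optimizer_cfg = recipe.build_optimizer_config(learning_rate=3e-3, beta2=0.95)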
+ """ + return CautiousConfig( + learning_rate=learning_rate, + weight_decay=self.weight_decay, + min_lr_ratio=self.min_lr_ratio, + warmup=self.warmup, + beta1=self.beta1, + beta2=beta2, + epsilon=self.epsilon, + max_grad_norm=self.max_grad_norm, + adamc_weight_decay=True, + lr_schedule=self.lr_schedule, + decay=self.decay, + ) + # Named recipes MARIN_2025_RECIPE = ScalingRecipe(name="marin-2025") diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 07461061ed..9fb330704c 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -65,7 +65,6 @@ CandidateConfig, IsoFlopSweepConfig, ScalingFit, - build_optimizer_config, isoflop_analysis_step, pick_v5p_type, predict_optimal_config, @@ -218,7 +217,7 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: tpu_type = pick_v5p_type(candidate, vocab_size, config.seq_len) - optimizer_cfg = build_optimizer_config(candidate, config.recipe) + optimizer_cfg = config.recipe.build_optimizer_config(candidate.learning_rate, candidate.beta2) pretraining_data = _prepare_data_config(config.tokenized, config.validation_sets) From 1e833004a7b8ed3b173a75e4b948ef429882d07d Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 6 Jan 2026 21:17:30 -0800 Subject: [PATCH 46/79] Fix Mismatch Now --- experiments/exp1600_perpcorr.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/experiments/exp1600_perpcorr.py b/experiments/exp1600_perpcorr.py index 1ab8d39ee6..8e91d7a191 100644 --- a/experiments/exp1600_perpcorr.py +++ b/experiments/exp1600_perpcorr.py @@ -24,7 +24,7 @@ from experiments.evals.evals import evaluate_levanter_lm_evaluation_harness from experiments.evals.task_configs import EvalTaskConfig -from experiments.isoflop_sweep import generate_isoflop_sweep +from experiments.isoflop_sweep import create_isoflop_sweep_steps from experiments.llama import llama3_tokenizer from experiments.models import ModelConfig as HFModelConfig, download_model_step from experiments.paloma import paloma_tokenized @@ -34,6 +34,7 @@ from marin.evaluation.log_probs import default_lm_log_probs from marin.execution.executor import executor_main, output_path_of from marin.processing.tokenize.data_configs import mixture_for_evaluation +from marin.scaling_laws.recipe import MARIN_2025_RECIPE # Import shared components from exp1600_uncheatable_evals from experiments.evals.exp1600_uncheatable_evals import ( @@ -56,22 +57,22 @@ @lru_cache(maxsize=1) def build_steps(): steps = [] - isoflop_steps, isoflop_metadatas = generate_isoflop_sweep( + isoflop_steps, isoflop_candidates = create_isoflop_sweep_steps( nemotron_mix, experiment_name="nemo-wider-depth-adapt", + recipe=MARIN_2025_RECIPE, ) - for isoflop_step, isoflop_metadata in zip(isoflop_steps, isoflop_metadatas, strict=False): + for isoflop_step, candidate in zip(isoflop_steps, isoflop_candidates, strict=False): experiment_name = isoflop_step.name.split("/")[-1] paloma_tokenized_dict = paloma_tokenized(tokenizer=llama3_tokenizer) uncheatable_eval_tokenized_dict = uncheatable_eval_tokenized(tokenizer=llama3_tokenizer) eval_data = mixture_for_evaluation(paloma_tokenized_dict | uncheatable_eval_tokenized_dict) - budget, hidden_size, num_layers, batch_size, train_steps = isoflop_metadata wandb_tags = [ - f"FLOPs={budget:.1e}", - f"d={hidden_size}", - f"L={num_layers}", - f"B={batch_size}", - f"steps={train_steps}", + f"FLOPs={candidate.flops_budget:.1e}", + 
f"d={candidate.hidden_size}", + f"L={candidate.num_layers}", + f"B={candidate.batch_size}", + f"steps={candidate.train_steps}", ] model_config = isoflop_step.config.train_config.model checkpoint_path = output_path_of(isoflop_step) From 86aedf10b2eb76e9fdad6fa08aca5f4ed870b5bd Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 7 Jan 2026 14:52:52 -0800 Subject: [PATCH 47/79] Tmp --- lib/marin/src/marin/scaling_laws/__init__.py | 8 +- .../marin/scaling_laws/eval_metrics_reader.py | 22 ++- .../marin/scaling_laws/isoflop_analysis.py | 179 ++++++++---------- .../src/marin/scaling_laws/scaling_ladder.py | 2 +- .../src/marin/scaling_laws/scaling_plots.py | 4 + lib/marin/src/marin/scaling_laws/tpu_utils.py | 131 +++++++++++++ 6 files changed, 245 insertions(+), 101 deletions(-) create mode 100644 lib/marin/src/marin/scaling_laws/tpu_utils.py diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index bb94bf3553..049dd3d5aa 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -14,6 +14,7 @@ from marin.scaling_laws.isoflop_analysis import ( DEFAULT_BUDGETS, + DEFAULT_EVAL_METRIC_KEY, CandidateConfig, FitScalingLawsResult, IsoFlopAnalysisConfig, @@ -32,7 +33,6 @@ generate_isoflop_train_args, isoflop_analysis_step, isoflop_plots_step, - pick_v5p_type, predict_optimal_config, predict_optimal_configs_for_budgets, run_isoflop_analysis, @@ -41,6 +41,10 @@ solve_for_train_steps, upload_isoflop_plots_to_wandb_step, ) +from marin.scaling_laws.tpu_utils import ( + estimate_memory_bytes, + pick_v5p_type, +) from marin.scaling_laws.recipe import ( MARIN_2025_RECIPE, ScalingRecipe, @@ -63,6 +67,7 @@ __all__ = [ "DEFAULT_BUDGETS", + "DEFAULT_EVAL_METRIC_KEY", "MARIN_2025_RECIPE", "CandidateConfig", "FitScalingLawsResult", @@ -85,6 +90,7 @@ "create_isoflop_plot", "create_scaling_plot", "default_model_builder", + "estimate_memory_bytes", "fit_scaling_laws", "generate_isoflop_train_args", "isoflop_analysis_step", diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index 17b92d232e..e95f189dbc 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -21,8 +21,8 @@ (like IsoFlop) should subclass EvalMetricsAnalysisConfig. """ -import logging import json +import logging import os from dataclasses import dataclass from collections.abc import Sequence @@ -30,6 +30,7 @@ import fsspec import pandas as pd +from marin.execution.executor import InputName from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT try: @@ -108,8 +109,8 @@ class EvalMetricsAnalysisConfig: The training_runs field creates blocking dependencies on the training jobs. 
""" - training_runs: Sequence[str] - """List of training run output paths to read eval metrics from (blocks until complete).""" + training_runs: Sequence[str | InputName] + """List of training run output paths (strings or InputNames resolved by the executor).""" output_path: str """Where to write analysis outputs.""" @@ -124,6 +125,20 @@ class EvalMetricsAnalysisConfig: """WandB entity/project to query for backfill (format: 'entity/project').""" +def _resolve_training_run_path(run_path: str | InputName) -> str: + """Return a concrete path string for a training run.""" + if isinstance(run_path, InputName): + if run_path.step is not None: + raise ValueError( + "Training runs must be concrete paths when calling read_metrics_dataframe directly. " + "Use the Executor to resolve InputName values." + ) + if run_path.name is None: + raise ValueError("Training run InputName must include a name.") + return run_path.name + return run_path + + def read_metrics_dataframe(config: EvalMetricsAnalysisConfig) -> pd.DataFrame: """ Read eval metrics from training runs into a DataFrame. @@ -140,6 +155,7 @@ def read_metrics_dataframe(config: EvalMetricsAnalysisConfig) -> pd.DataFrame: all_records = [] for i, run_path in enumerate(config.training_runs): + run_path = _resolve_training_run_path(run_path) metrics_file = os.path.join(run_path, config.metrics_filename) fs, _, _ = fsspec.get_fs_token_paths(metrics_file) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 337019bed3..a159d10d88 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -73,12 +73,19 @@ read_metrics_dataframe, ) from marin.scaling_laws.recipe import MARIN_2025_RECIPE, ScalingRecipe +from marin.scaling_laws.tpu_utils import ( + pick_v5p_type, +) from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT logger = logging.getLogger(__name__) # ---------------- Constants ---------------- -DEFAULT_METRIC_KEY = "eval/paloma/c4_en/bpb" + +# Paloma is a standard LLM evaluation benchmark. C4-en BPB (bits-per-byte) is a +# common loss metric that measures model perplexity on the C4 English dataset. +# See: https://arxiv.org/abs/2312.10523 +DEFAULT_EVAL_METRIC_KEY = "eval/paloma/c4_en/bpb" SEQ_LEN = 4096 # Marin tokenizer vocab size (stanford-crfm/marin-tokenizer) @@ -89,10 +96,16 @@ # This matches how FLOPs are tracked in WandB via Levanter's log_performance_stats. DEFAULT_BUDGETS: tuple[float, ...] = (1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20) -# TPU v5p hardware constants for memory estimation -HBM_PER_CHIP_GIB = 95 -CORES_PER_CHIP = 2 -V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512] +# Derived from Kaiyue's hyperparameter sweep: optimal_LR * hidden_size * sqrt(batch_size) +LR_CONSTANT = 0.33 + +# ---------------- WandB Metric Keys ---------------- +# These keys correspond to the metrics logged by Levanter's training callbacks. +THROUGHPUT_TOKENS_KEY = "throughput/total_tokens" +THROUGHPUT_GFLOPS_KEY = "throughput/total_gflops" +PARAMETER_COUNT_KEY = "parameter_count" +MODEL_CONFIG_KEY = "model" +TRAINER_CONFIG_KEY = "trainer" # ---------------- Typed Tuples ---------------- @@ -156,7 +169,13 @@ class IsoFlopSweepConfig: """Sequence length for training.""" steps_per_run: int = 2**16 - """Target number of training steps per run.""" + """Number of training steps used for FLOP budget calculation and hyperparameter tuning. 
+ + This is the reference step count that other hyperparameters (LR, beta2) are tuned for. + The actual training steps may differ based on batch size to hit the target FLOP budget. + Default of 2^16 (65,536) steps is used because the LR_CONSTANT and other tuned values + were optimized for this step count. + """ flop_tolerance: float = 0.01 """Tolerance for matching FLOP budget (relative error).""" @@ -276,6 +295,11 @@ def round_flops_to_bucket(flops: float) -> float: This ensures runs with slightly different achieved FLOPs are grouped together for analysis when they were targeting the same budget. + Using 1 significant figure creates buckets at 1e19, 2e19, 3e19, etc., + which matches the typical spacing of isoflop budget targets. + + Note: This means 1.5e19 and 2.4e19 both map to 2e19. For finer granularity, + consider using 2 significant figures (round to nearest 0.1 mantissa). Examples: 1.05e19 → 1e19 @@ -425,82 +449,6 @@ def compute_transformer_params( return transformer + token_embedding + lm_head -def estimate_memory_bytes( - param_count: int, - hidden_dim: int, - num_layers: int, - batch: int, - seq_len: int, - vocab: int, - optim_mult: int = 3, - dtype_size: int = 4, - fudge_factor: float = 2, -) -> int: - """ - Estimate float32 memory usage (in bytes) for one training step. - - Parameters: - - param_count: number of model parameters - - hidden_dim: model hidden size - - num_layers: number of Transformer layers - - batch, seq_len: training batch size and sequence length - - vocab: vocabulary size - - optim_mult: optimizer memory multiplier (e.g., 3x for Adam + states) - - dtype_size: bytes per float (4 for float32) - - fudge_factor: safety margin for extra memory - - Returns: - - total estimated memory in bytes - """ - param_bytes = param_count * optim_mult * dtype_size - act_bytes = (batch * seq_len) * ((hidden_dim * num_layers) + vocab * fudge_factor) - total_bytes = param_bytes + act_bytes - return int(total_bytes * fudge_factor) - - -def pick_v5p_type( - candidate: CandidateConfig, - vocab_size: int, - seq_len: int, -) -> str: - """ - Select the smallest TPU v5p slice that fits the model in float32. - - Args: - candidate: CandidateConfig with model architecture parameters. - vocab_size: Vocabulary size. - seq_len: Sequence length. - - Returns: - TPU slice name, e.g., "v5p-8" or "v5p-32" - """ - param_count = compute_transformer_params( - hidden_dim=candidate.hidden_size, - intermediate_dim=candidate.intermediate_dim, - num_layers=candidate.num_layers, - num_heads=candidate.num_heads, - num_kv_heads=candidate.num_kv_heads, - vocab_size=vocab_size, - ) - need_bytes = estimate_memory_bytes( - param_count, - candidate.hidden_size, - candidate.num_layers, - candidate.batch_size, - seq_len, - vocab_size, - ) - chip_bytes = HBM_PER_CHIP_GIB * 1024**3 - chips = math.ceil(need_bytes / chip_bytes) - cores_req = chips * CORES_PER_CHIP - - valid = [c for c in V5P_CORE_OPTIONS if c >= cores_req] - if not valid: - raise ValueError(f"Model too large for available v5p slices (need {cores_req} cores).") - - return f"v5p-{min(valid)}" - - def candidate_configs( cfg: IsoFlopSweepConfig, budget: float, @@ -716,6 +664,21 @@ def generate_isoflop_train_args( # ---------------- Helpers ---------------- +def _resolve_run_paths(runs: Sequence[ExecutorStep | InputName | str]) -> list[InputName | str]: + """Convert mixed ExecutorStep/InputName/path inputs to executor-ready paths. + + This helper reduces duplication across functions that accept either + ExecutorSteps or string paths. 
+ + Args: + runs: Sequence of ExecutorStep, InputName, or path strings. + + Returns: + List of InputName or string paths. + """ + return [output_path_of(run) if isinstance(run, ExecutorStep) else run for run in runs] + + def parse_isoflop_run_name(run_name: str) -> str | None: """Parse experiment name from isoflop run name. @@ -740,7 +703,20 @@ def parse_isoflop_run_name(run_name: str) -> str | None: def robust_quad_logx(x: jnp.ndarray, y: jnp.ndarray, delta: float = 1.0) -> tuple[float, float, float]: """Fit a robust quadratic in log10(x) space using Huber loss. - Returns (a, b, c) coefficients for: loss = a * log10(x)^2 + b * log10(x) + c + Log10 space is used because FLOP budgets and token counts span many orders of + magnitude (e.g., 1e18 to 1e21+). Fitting in linear space would be numerically + unstable and dominated by the largest values. Log space provides better + conditioning and more interpretable coefficients. + + The Huber loss provides robustness to outliers compared to standard least squares. + + Args: + x: Input array (e.g., token counts). Must be positive. + y: Output array (e.g., loss values). + delta: Huber loss threshold. Residuals larger than delta use linear loss. + + Returns: + Tuple (a, b, c) of coefficients for: loss = a * log10(x)^2 + b * log10(x) + c """ L = jnp.log10(x) @@ -891,19 +867,19 @@ def transform_metrics_for_isoflop( # Extract config and summary dicts config = row.get("config", {}) or {} summary = row.get("summary", {}) or {} - model_config = config.get("model", {}) or {} - trainer_config = config.get("trainer", {}) or {} + model_config = config.get(MODEL_CONFIG_KEY, {}) or {} + trainer_config = config.get(TRAINER_CONFIG_KEY, {}) or {} # Get tokens directly from summary - tokens = summary.get("throughput/total_tokens") + tokens = summary.get(THROUGHPUT_TOKENS_KEY) if tokens is None or pd.isna(tokens): - logger.warning(f"Missing throughput/total_tokens in summary for run {run_name}") + logger.warning(f"Missing {THROUGHPUT_TOKENS_KEY} in summary for run {run_name}") continue # Get total FLOPs from summary (convert GFLOPs to FLOPs) - total_gflops = summary.get("throughput/total_gflops") + total_gflops = summary.get(THROUGHPUT_GFLOPS_KEY) if total_gflops is None or pd.isna(total_gflops): - logger.warning(f"Missing throughput/total_gflops in summary for run {run_name}") + logger.warning(f"Missing {THROUGHPUT_GFLOPS_KEY} in summary for run {run_name}") continue flops = round_flops_to_bucket(total_gflops * 1e9) @@ -917,7 +893,7 @@ def transform_metrics_for_isoflop( continue # Get parameter count from summary - params = summary.get("parameter_count") + params = summary.get(PARAMETER_COUNT_KEY) if params is None or pd.isna(params): params = None @@ -962,6 +938,17 @@ def predict_optimal_config( ) -> CandidateConfig | None: """Predict optimal training config for a target compute budget using fitted scaling laws. + This implements IsoFLOP Approach 2 from the Chinchilla paper: + 1. D_opt (optimal tokens) is found empirically at each compute budget by fitting + parabolas to actual loss values and finding the minimum. + 2. D_opt ~ A * C^alpha is fitted from those empirical minima. + 3. Given D_opt and C, N_opt (optimal params) is derived as C/(6D), so no + separate alpha fit for params is needed. + + This approach works regardless of whether the scaling exponents for params + and tokens are equal (alpha == beta), unlike Approach 3 which fits a + parametric loss surface. + This function: 1. 
Uses the scaling fit (N* ~ A * C^alpha) to predict optimal tokens for target_flops 2. Generates candidate configs for the target budget using candidate_configs() @@ -1089,8 +1076,8 @@ class IsoFlopAnalysisConfig(EvalMetricsAnalysisConfig): recipe: ScalingRecipe = MARIN_2025_RECIPE """Scaling recipe for computing optimal hyperparameters.""" - metric_key: str = DEFAULT_METRIC_KEY - """Metric to use for loss (default: eval/paloma/c4_en/bpb).""" + metric_key: str = DEFAULT_EVAL_METRIC_KEY + """Metric to use for loss (default: eval/paloma/c4_en/bpb - Paloma benchmark on C4 English).""" label_map: tuple[tuple[str, str], ...] | None = None """Optional mapping from experiment_name -> display label as tuple of pairs.""" @@ -1271,7 +1258,7 @@ def _run_upload_plots_to_wandb_step(config: UploadPlotsToWandbConfig) -> None: def isoflop_analysis_step( name: str, training_runs: Sequence[ExecutorStep | InputName], - metric_key: str = DEFAULT_METRIC_KEY, + metric_key: str = DEFAULT_EVAL_METRIC_KEY, label_map: dict[str, str] | None = None, recipe: ScalingRecipe = MARIN_2025_RECIPE, ) -> ExecutorStep: @@ -1302,7 +1289,7 @@ def isoflop_analysis_step( ... analysis_step=analysis, ... ) """ - run_paths = [output_path_of(run) if isinstance(run, ExecutorStep) else run for run in training_runs] + run_paths = _resolve_run_paths(training_runs) config = IsoFlopAnalysisConfig( training_runs=run_paths, @@ -1405,7 +1392,7 @@ def upload_isoflop_plots_to_wandb_step( def run_isoflop_analysis( training_runs: Sequence[ExecutorStep] | Sequence[str], - metric_key: str = DEFAULT_METRIC_KEY, + metric_key: str = DEFAULT_EVAL_METRIC_KEY, label_map: dict[str, str] | None = None, recipe: ScalingRecipe = MARIN_2025_RECIPE, ) -> IsoFlopAnalysisResult: @@ -1423,7 +1410,7 @@ def run_isoflop_analysis( Returns: IsoFlopAnalysisResult with configs, scaling_fits, and analysis data """ - run_paths = [output_path_of(run) if isinstance(run, ExecutorStep) else run for run in training_runs] + run_paths = _resolve_run_paths(training_runs) config = EvalMetricsAnalysisConfig( training_runs=run_paths, diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 9fb330704c..3457e6499d 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -421,7 +421,7 @@ def scaling_ladder_suite( optimal_runs = [] for budget in target_budgets: run_step = scaling_ladder_rung_step( - name=f"{name}-optimal-{budget:.0e}", + name=f"{name}-optimal-{budget:.2e}", analysis_step=analysis, target_budget=budget, label=label, diff --git a/lib/marin/src/marin/scaling_laws/scaling_plots.py b/lib/marin/src/marin/scaling_laws/scaling_plots.py index 106685ad76..6680d7307e 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_plots.py +++ b/lib/marin/src/marin/scaling_laws/scaling_plots.py @@ -306,6 +306,10 @@ def upload_plots_to_wandb( entity: WandB entity project: WandB project run_name: Name for the WandB run + + TODO: Consider extracting a generic wandb-upload utility that takes artifacts + and handles upload logic. This would decouple the plotting logic from WandB + and allow reuse across other analysis tools. 
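+
+    Example (illustrative call; the figure objects come from create_isoflop_plot and
+    create_scaling_plot, and the entity/project names are placeholders):
+
+        >>> upload_plots_to_wandb(fig_isoflop, fig_scaling, entity="my-entity",
+        ...                       project="my-project", run_name="scaling-plots")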
""" wandb.login() run = wandb.init( diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py new file mode 100644 index 0000000000..5b645066de --- /dev/null +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -0,0 +1,131 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TPU hardware utilities for memory estimation and slice selection. + +This module provides utilities for estimating memory requirements and +selecting appropriate TPU slice sizes for training runs. +""" + +import math +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from marin.scaling_laws.isoflop_analysis import CandidateConfig + +# ---------------- TPU v5p Hardware Constants ---------------- +# These constants are specific to TPU v5p pods. + +HBM_PER_CHIP_GIB = 95 +"""High-bandwidth memory per TPU v5p chip in GiB.""" + +CORES_PER_CHIP = 2 +"""Number of cores per TPU v5p chip.""" + +V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512] +"""Available TPU v5p core configurations (slice sizes).""" + + +def estimate_memory_bytes( + param_count: int, + hidden_dim: int, + num_layers: int, + batch: int, + seq_len: int, + vocab: int, + optim_mult: int = 3, + dtype_size: int = 4, + fudge_factor: float = 2, +) -> int: + """Estimate float32 memory usage (in bytes) for one training step. + + This is a conservative estimate for LLaMA-style architectures with + Adam optimizer. The fudge_factor provides a safety margin for + additional memory overhead not captured in the simple model. + + Args: + param_count: Number of model parameters. + hidden_dim: Model hidden size. + num_layers: Number of Transformer layers. + batch: Training batch size. + seq_len: Sequence length. + vocab: Vocabulary size. + optim_mult: Optimizer memory multiplier (default 3 for Adam with + momentum and variance states). + dtype_size: Bytes per float (default 4 for float32). + fudge_factor: Safety margin multiplier (default 2x). + + Returns: + Estimated total memory in bytes. + + Note: + This assumes a LLaMA-style architecture with Adam optimizer in float32. + Actual memory usage may vary based on specific model architecture, + optimizer choice, and JAX/XLA memory optimizations. + """ + param_bytes = param_count * optim_mult * dtype_size + act_bytes = (batch * seq_len) * ((hidden_dim * num_layers) + vocab * fudge_factor) + total_bytes = param_bytes + act_bytes + return int(total_bytes * fudge_factor) + + +def pick_v5p_type( + candidate: "CandidateConfig", + vocab_size: int, + seq_len: int, +) -> str: + """Select the smallest TPU v5p slice that fits the model in float32. + + Uses conservative memory estimation to select a TPU slice size that + can accommodate the model parameters, optimizer states, and activations. + + Args: + candidate: CandidateConfig with model architecture parameters. + vocab_size: Vocabulary size. + seq_len: Sequence length. + + Returns: + TPU slice name, e.g., "v5p-8" or "v5p-32". 
+ + Raises: + ValueError: If the model is too large for available v5p slices. + """ + # Import here to avoid circular dependency + from marin.scaling_laws.isoflop_analysis import compute_transformer_params + + param_count = compute_transformer_params( + hidden_dim=candidate.hidden_size, + intermediate_dim=candidate.intermediate_dim, + num_layers=candidate.num_layers, + num_heads=candidate.num_heads, + num_kv_heads=candidate.num_kv_heads, + vocab_size=vocab_size, + ) + need_bytes = estimate_memory_bytes( + param_count, + candidate.hidden_size, + candidate.num_layers, + candidate.batch_size, + seq_len, + vocab_size, + ) + chip_bytes = HBM_PER_CHIP_GIB * 1024**3 + chips = math.ceil(need_bytes / chip_bytes) + cores_req = chips * CORES_PER_CHIP + + valid = [c for c in V5P_CORE_OPTIONS if c >= cores_req] + if not valid: + raise ValueError(f"Model too large for available v5p slices (need {cores_req} cores).") + + return f"v5p-{min(valid)}" From c8a2d13e533da1e87e77d4cb042b10b2daa3e8fd Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 7 Jan 2026 16:46:31 -0800 Subject: [PATCH 48/79] Try to separate concerns to experiments more --- experiments/exp1600_perpcorr.py | 3 +- .../exp2166_scaling_ladder_analysis.py | 3 +- experiments/isoflop_sweep.py | 60 +- lib/marin/src/marin/scaling_laws/__init__.py | 29 +- .../marin/scaling_laws/eval_metrics_reader.py | 20 +- .../marin/scaling_laws/isoflop_analysis.py | 631 +++--------------- lib/marin/src/marin/scaling_laws/recipe.py | 336 ++++++++-- .../src/marin/scaling_laws/scaling_ladder.py | 251 +------ 8 files changed, 475 insertions(+), 858 deletions(-) diff --git a/experiments/exp1600_perpcorr.py b/experiments/exp1600_perpcorr.py index 8e91d7a191..c0cc140def 100644 --- a/experiments/exp1600_perpcorr.py +++ b/experiments/exp1600_perpcorr.py @@ -24,7 +24,7 @@ from experiments.evals.evals import evaluate_levanter_lm_evaluation_harness from experiments.evals.task_configs import EvalTaskConfig -from experiments.isoflop_sweep import create_isoflop_sweep_steps +from experiments.isoflop_sweep import MARIN_2025_RECIPE, create_isoflop_sweep_steps from experiments.llama import llama3_tokenizer from experiments.models import ModelConfig as HFModelConfig, download_model_step from experiments.paloma import paloma_tokenized @@ -34,7 +34,6 @@ from marin.evaluation.log_probs import default_lm_log_probs from marin.execution.executor import executor_main, output_path_of from marin.processing.tokenize.data_configs import mixture_for_evaluation -from marin.scaling_laws.recipe import MARIN_2025_RECIPE # Import shared components from exp1600_uncheatable_evals from experiments.evals.exp1600_uncheatable_evals import ( diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 1c5f1c187e..e0f784e9da 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -30,7 +30,7 @@ """ from experiments.defaults import default_validation_sets -from experiments.isoflop_sweep import MARIN_SCALING_SUITES, nemotron_mix +from experiments.isoflop_sweep import MARIN_2025_RECIPE, MARIN_SCALING_SUITES, nemotron_mix from marin.execution.executor import ExecutorStep, executor_main, output_path_of from marin.scaling_laws import ( IsoFlopAnalysisConfig, @@ -38,7 +38,6 @@ run_isoflop_analysis_step, run_scaling_ladder_rung, ) -from marin.scaling_laws.recipe import MARIN_2025_RECIPE # Get training steps from the isoflop sweep nemotron_training, _ = MARIN_SCALING_SUITES["nemotron"] diff --git 
a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 2e783b0c70..65b82b67dc 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -21,13 +21,11 @@ ExecutorSteps are created directly in this experiment file, following the pattern of isolating executor step creation to experiments. The library provides: - `generate_isoflop_train_args()`: Computes model/optimizer configs for each sweep point -- `IsoFlopSweepConfig`: Configuration for the sweep parameters -- `ScalingRecipe`: Named hyperparameter bundle +- `ScalingRecipe`: Named hyperparameter bundle with architecture and optimizer settings This file uses those to create the actual ExecutorSteps. """ -import dataclasses from dataclasses import replace from levanter.data.text import LMMixtureDatasetConfig @@ -46,11 +44,19 @@ from marin.execution.executor import ExecutorStep, InputName, executor_main from marin.processing.tokenize import get_vocab_size_for_tokenizer, lm_mixture_data_config from marin.scaling_laws import ( + DEFAULT_BUDGETS, CandidateConfig, - IsoFlopSweepConfig, generate_isoflop_train_args, ) -from marin.scaling_laws.recipe import MARIN_2025_RECIPE, ScalingRecipe +from marin.scaling_laws import ScalingRecipe + + +# --- Scaling Recipe --- +# This recipe encapsulates all model-specific hyperparameters for Marin scaling experiments. +# Other experiments can define their own recipes by instantiating ScalingRecipe with different values. + +MARIN_2025_RECIPE = ScalingRecipe(name="marin-2025") +"""Default Marin scaling recipe based on 2025 best practices.""" def build_qwen3_from_candidate(candidate: CandidateConfig, seq_len: int = 4096) -> Qwen3Config: @@ -75,8 +81,10 @@ def create_isoflop_sweep_steps( tokenized: InputName | str | LMMixtureDatasetConfig, experiment_name: str, recipe: ScalingRecipe, - sweep_config: IsoFlopSweepConfig | None = None, + budgets: tuple[float, ...] = DEFAULT_BUDGETS, + tokenizer: str = "stanford-crfm/marin-tokenizer", eval_tasks: tuple[EvalTaskConfig, ...] | None = None, + seq_len: int = 4096, ) -> tuple[list[ExecutorStep], list[CandidateConfig]]: """Create ExecutorSteps for an ISOFlop sweep. @@ -87,7 +95,8 @@ def create_isoflop_sweep_steps( tokenized: Tokenized dataset to train on. experiment_name: Name suffix for the experiment (e.g., 'nemo', 'dclm'). recipe: ScalingRecipe with hyperparameters - must be explicitly specified. - sweep_config: Optional custom sweep config. Uses defaults with the recipe if None. + budgets: FLOP budgets to sweep over. + tokenizer: Tokenizer to use for vocab size. eval_tasks: Optional evaluation tasks to run after training. Returns: @@ -95,19 +104,14 @@ def create_isoflop_sweep_steps( - steps: Training and evaluation ExecutorSteps for the sweep. - candidates: CandidateConfig for each training run with full config details. """ - # Build sweep config with the specified recipe - if sweep_config is None: - sweep_config = IsoFlopSweepConfig(recipe=recipe) - else: - sweep_config = dataclasses.replace(sweep_config, recipe=recipe) - - vocab_size = get_vocab_size_for_tokenizer(sweep_config.tokenizer) + vocab_size = get_vocab_size_for_tokenizer(tokenizer) # Library provides the training arguments (model configs, optimizer configs, etc.) 
train_args_list = generate_isoflop_train_args( - sweep_config=sweep_config, + budgets=budgets, experiment_name=experiment_name, vocab_size=vocab_size, + recipe=recipe, ) # Base config for training runs @@ -129,7 +133,7 @@ def create_isoflop_sweep_steps( # Create ExecutorSteps for each candidate configuration for args in train_args_list: # Build model config from candidate (experiment controls model type) - model_config = build_qwen3_from_candidate(args.candidate, sweep_config.seq_len) + model_config = build_qwen3_from_candidate(args.candidate, seq_len) train_cfg = replace( base_train_config, @@ -170,13 +174,11 @@ def create_isoflop_sweep_steps( # --- Tokenized Datasets --- -dclm_tokenized = dataclasses.replace( - default_tokenize( - name="dclm_baseline", - dataset=downloads["dclm_baseline"], - tokenizer=llama3_tokenizer, - ).with_output_path("tokenized/dclm_baseline-0206f1/"), -) +dclm_tokenized = default_tokenize( + name="dclm_baseline", + dataset=downloads["dclm_baseline"], + tokenizer=llama3_tokenizer, +).with_output_path("tokenized/dclm_baseline-0206f1/") dclm_mix = lm_mixture_data_config( components={"dclm": dclm_tokenized}, @@ -184,13 +186,11 @@ def create_isoflop_sweep_steps( num_validation_sequences={"dclm": 1024}, ) -dolma3_mix_tokenized = dataclasses.replace( - default_tokenize( - name="dolma3_mix-150B-1025", - dataset=downloads["dolma3_mix_150b_1025"], - tokenizer=llama3_tokenizer, - ).with_output_path("tokenized/dolma3_mix-150B-1025-15d04ee/"), -) +dolma3_mix_tokenized = default_tokenize( + name="dolma3_mix-150B-1025", + dataset=downloads["dolma3_mix_150b_1025"], + tokenizer=llama3_tokenizer, +).with_output_path("tokenized/dolma3_mix-150B-1025-15d04ee/") dolma3_mix = lm_mixture_data_config( components={"dolma3_mix-150B-1025": dolma3_mix_tokenized}, diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 049dd3d5aa..c58ecd2f7e 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -15,48 +15,41 @@ from marin.scaling_laws.isoflop_analysis import ( DEFAULT_BUDGETS, DEFAULT_EVAL_METRIC_KEY, + DEFAULT_FLOP_TOLERANCE, + DEFAULT_SEQ_LEN, + DEFAULT_STEPS_PER_RUN, CandidateConfig, FitScalingLawsResult, IsoFlopAnalysisConfig, IsoFlopAnalysisResult, - IsoFlopPlotsConfig, - IsoFlopSweepConfig, IsoFlopTrainArgs, MinimaRecord, QuadraticFitCoeffs, ScalingFit, - UploadPlotsToWandbConfig, candidate_configs, compute_training_flops, compute_transformer_params, fit_scaling_laws, generate_isoflop_train_args, - isoflop_analysis_step, - isoflop_plots_step, predict_optimal_config, predict_optimal_configs_for_budgets, run_isoflop_analysis, run_isoflop_analysis_step, solve_for_batch_size, solve_for_train_steps, - upload_isoflop_plots_to_wandb_step, ) from marin.scaling_laws.tpu_utils import ( estimate_memory_bytes, pick_v5p_type, ) from marin.scaling_laws.recipe import ( - MARIN_2025_RECIPE, ScalingRecipe, ) from marin.scaling_laws.scaling_ladder import ( ModelBuilder, ScalingLadderRungConfig, - ScalingLadderSuite, default_model_builder, run_scaling_ladder_rung, - scaling_ladder_rung_step, - scaling_ladder_suite, ) from marin.scaling_laws.scaling_plots import ( create_isoflop_plot, @@ -66,24 +59,25 @@ ) __all__ = [ + # Constants "DEFAULT_BUDGETS", "DEFAULT_EVAL_METRIC_KEY", - "MARIN_2025_RECIPE", + "DEFAULT_FLOP_TOLERANCE", + "DEFAULT_SEQ_LEN", + "DEFAULT_STEPS_PER_RUN", + # Data classes "CandidateConfig", "FitScalingLawsResult", "IsoFlopAnalysisConfig", "IsoFlopAnalysisResult", - 
"IsoFlopPlotsConfig", - "IsoFlopSweepConfig", "IsoFlopTrainArgs", "MinimaRecord", "ModelBuilder", "QuadraticFitCoeffs", "ScalingFit", "ScalingLadderRungConfig", - "ScalingLadderSuite", "ScalingRecipe", - "UploadPlotsToWandbConfig", + # Functions "candidate_configs", "compute_training_flops", "compute_transformer_params", @@ -93,8 +87,6 @@ "estimate_memory_bytes", "fit_scaling_laws", "generate_isoflop_train_args", - "isoflop_analysis_step", - "isoflop_plots_step", "pick_v5p_type", "predict_optimal_config", "predict_optimal_configs_for_budgets", @@ -102,10 +94,7 @@ "run_isoflop_analysis_step", "run_scaling_ladder_rung", "save_plots", - "scaling_ladder_rung_step", - "scaling_ladder_suite", "solve_for_batch_size", "solve_for_train_steps", - "upload_isoflop_plots_to_wandb_step", "upload_plots_to_wandb", ] diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index e95f189dbc..a8d8abef63 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -30,7 +30,6 @@ import fsspec import pandas as pd -from marin.execution.executor import InputName from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT try: @@ -109,8 +108,8 @@ class EvalMetricsAnalysisConfig: The training_runs field creates blocking dependencies on the training jobs. """ - training_runs: Sequence[str | InputName] - """List of training run output paths (strings or InputNames resolved by the executor).""" + training_runs: Sequence[str] + """List of training run output paths (executor resolves InputName to str at runtime).""" output_path: str """Where to write analysis outputs.""" @@ -125,20 +124,6 @@ class EvalMetricsAnalysisConfig: """WandB entity/project to query for backfill (format: 'entity/project').""" -def _resolve_training_run_path(run_path: str | InputName) -> str: - """Return a concrete path string for a training run.""" - if isinstance(run_path, InputName): - if run_path.step is not None: - raise ValueError( - "Training runs must be concrete paths when calling read_metrics_dataframe directly. " - "Use the Executor to resolve InputName values." - ) - if run_path.name is None: - raise ValueError("Training run InputName must include a name.") - return run_path.name - return run_path - - def read_metrics_dataframe(config: EvalMetricsAnalysisConfig) -> pd.DataFrame: """ Read eval metrics from training runs into a DataFrame. @@ -155,7 +140,6 @@ def read_metrics_dataframe(config: EvalMetricsAnalysisConfig) -> pd.DataFrame: all_records = [] for i, run_path in enumerate(config.training_runs): - run_path = _resolve_training_run_path(run_path) metrics_file = os.path.join(run_path, config.metrics_filename) fs, _, _ = fsspec.get_fs_token_paths(metrics_file) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index a159d10d88..c91f556cc8 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -14,30 +14,25 @@ """IsoFLOP analysis for finding compute-optimal training configurations. -Primary usage - create ExecutorSteps for your pipeline: +This module provides functions and configs for IsoFLOP scaling law analysis. +Experiments create ExecutorSteps directly using the provided functions. 
+Example usage in experiments: + + from marin.execution.executor import ExecutorStep, output_path_of from marin.scaling_laws import ( - isoflop_analysis_step, - isoflop_plots_step, - upload_isoflop_plots_to_wandb_step, + IsoFlopAnalysisConfig, + run_isoflop_analysis_step, ) - # Step 1: Compute metrics and fit scaling laws - analysis = isoflop_analysis_step( + # Create analysis step + analysis_step = ExecutorStep( name="my-scaling-analysis", - training_runs=my_training_steps, # list of ExecutorStep - ) - - # Step 2: Generate HTML plots (optional) - plots = isoflop_plots_step( - name="my-scaling-plots", - analysis_step=analysis, - ) - - # Step 3: Upload to WandB (optional) - upload = upload_isoflop_plots_to_wandb_step( - name="upload-scaling-plots", - analysis_step=analysis, + fn=run_isoflop_analysis_step, + config=IsoFlopAnalysisConfig( + training_runs=[output_path_of(r) for r in training_runs], + output_path="analysis/my-analysis", + ), ) The analysis step will: @@ -45,7 +40,7 @@ 2. Fit scaling laws to find compute-optimal token counts 3. Save results to JSON/parquet files -For programmatic use, see `run_isoflop_analysis()` which returns a `IsoFlopAnalysisResult`. +For programmatic use (without ExecutorStep), see `run_isoflop_analysis()`. """ import json @@ -54,7 +49,7 @@ import os import re from collections.abc import Iterator, Sequence -from dataclasses import asdict, dataclass, replace +from dataclasses import asdict, dataclass, field from typing import NamedTuple import fsspec @@ -63,20 +58,14 @@ from jaxopt import ScipyMinimize from levanter.models.llama import LlamaConfig -from levanter.optim.cautious import CautiousConfig from levanter.optim.config import OptimizerConfig -from marin.execution.executor import ExecutorStep, InputName, output_path_of, this_output_path from marin.scaling_laws.eval_metrics_reader import ( EvalMetricsAnalysisConfig, extract_run_name_from_path, read_metrics_dataframe, ) -from marin.scaling_laws.recipe import MARIN_2025_RECIPE, ScalingRecipe -from marin.scaling_laws.tpu_utils import ( - pick_v5p_type, -) -from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT +from marin.scaling_laws.recipe import ScalingRecipe logger = logging.getLogger(__name__) @@ -140,45 +129,10 @@ class QuadraticFitCoeffs(NamedTuple): """Maximum token count used for fitting.""" -# ---------------- IsoFLOP Sweep Config ---------------- -@dataclass(frozen=True) -class IsoFlopSweepConfig: - """Configuration for generating ISOFlop sweep candidate configs. - - This config controls the FLOP budgets and training parameters. - Architecture decisions (num_layers formula, hidden_pow bounds, etc.) - are controlled by the ScalingRecipe. - """ - - recipe: ScalingRecipe = MARIN_2025_RECIPE - """Scaling recipe with all opinionated hyperparameters: - - Architecture formula (num_layers from hidden_size) - - Architecture ratios (mlp_ratio, hidden_head_ratio) - - Search bounds (min/max hidden_pow, step_size) - - Constraints (max_learning_rate, min_batch_size) - - Optimizer settings (weight_decay, warmup, etc.) - """ - - tokenizer: str = "stanford-crfm/marin-tokenizer" - """Tokenizer to use (needed for vocab size).""" - - budgets: tuple[float, ...] = DEFAULT_BUDGETS - """Tuple of FLOP budgets to generate configs for.""" - - seq_len: int = SEQ_LEN - """Sequence length for training.""" - - steps_per_run: int = 2**16 - """Number of training steps used for FLOP budget calculation and hyperparameter tuning. - - This is the reference step count that other hyperparameters (LR, beta2) are tuned for. 
- The actual training steps may differ based on batch size to hit the target FLOP budget. - Default of 2^16 (65,536) steps is used because the LR_CONSTANT and other tuned values - were optimized for this step count. - """ - - flop_tolerance: float = 0.01 - """Tolerance for matching FLOP budget (relative error).""" +# ---------------- IsoFLOP Sweep Defaults ---------------- +DEFAULT_SEQ_LEN = SEQ_LEN +DEFAULT_STEPS_PER_RUN = 2**16 # Reference step count for hyperparameter tuning +DEFAULT_FLOP_TOLERANCE = 0.01 # Relative error tolerance for FLOP budget # ---------------- Candidate Config ---------------- @@ -450,84 +404,36 @@ def compute_transformer_params( def candidate_configs( - cfg: IsoFlopSweepConfig, budget: float, vocab_size: int, + recipe: ScalingRecipe, + seq_len: int = DEFAULT_SEQ_LEN, + steps_per_run: int = DEFAULT_STEPS_PER_RUN, + flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> Iterator[CandidateConfig]: """Yield candidate model configurations within the FLOP budget. - This function uses the recipe for all opinionated choices: - - Architecture formula (num_layers from hidden_size) - - Architecture ratios (mlp_ratio, hidden_head_ratio) - - Search bounds (min/max hidden_pow, step_size) - - Constraints (max_learning_rate, min_batch_size) - - The mechanics layer (solve_for_batch_size, solve_for_train_steps, compute_training_flops) - handles the pure FLOP math. + This is a convenience function that delegates to recipe.candidate_configs(). + The recipe encapsulates all model-specific decisions (architecture formula, + search bounds, constraints), while this function provides backward compatibility. Args: - cfg: IsoFlopSweepConfig with recipe and other search parameters - budget: Target FLOP budget - vocab_size: Vocabulary size for the tokenizer + budget: Target FLOP budget. + vocab_size: Vocabulary size for the tokenizer. + recipe: ScalingRecipe with architecture/hyperparameter settings. + seq_len: Sequence length for training. + steps_per_run: Reference step count for FLOP budget calculation. + flop_tolerance: Tolerance for matching FLOP budget (relative error). Yields: - CandidateConfig objects for each valid configuration + CandidateConfig objects for each valid configuration. 
""" - recipe = cfg.recipe - - # RECIPE: Get search parameters - step_size = recipe.get_step_size(budget) - min_hidden = 2**recipe.min_hidden_pow - max_hidden = 2**recipe.max_hidden_pow - - for hidden_size in range(min_hidden, max_hidden + 1, step_size): - # RECIPE: Build model config (makes all architecture decisions) - model_config = recipe.build_model_config(hidden_size, cfg.seq_len) - - # MECHANICS: Solve for batch size to hit budget with target steps - batch_exact = solve_for_batch_size(model_config, vocab_size, budget, cfg.steps_per_run, cfg.seq_len) - batch_size = round_to_power_of_two(batch_exact) - - # RECIPE: Apply LR constraint - lr = recipe.compute_learning_rate(batch_size, hidden_size) - while lr > recipe.max_learning_rate: - batch_size //= 2 - lr = recipe.compute_learning_rate(batch_size, hidden_size) - - # RECIPE: Apply min batch constraint - if batch_size < recipe.min_batch_size: - continue - - # MECHANICS: Solve for steps to hit budget with chosen batch - train_steps = round(solve_for_train_steps(model_config, vocab_size, budget, batch_size, cfg.seq_len)) - - # MECHANICS: Verify we hit the budget within tolerance - achieved_flops = compute_training_flops(model_config, vocab_size, batch_size, train_steps, cfg.seq_len) - if abs(achieved_flops - budget) / budget > cfg.flop_tolerance: - continue - - # RECIPE: Compute optimizer hyperparameters - beta2 = recipe.compute_beta2(batch_size) - tokens = batch_size * train_steps * cfg.seq_len - - yield CandidateConfig( - hidden_size=hidden_size, - intermediate_dim=model_config.intermediate_dim, - num_layers=model_config.num_layers, - num_heads=model_config.num_heads, - num_kv_heads=model_config.num_kv_heads, - batch_size=batch_size, - train_steps=train_steps, - learning_rate=lr, - beta2=beta2, - tokens=tokens, - flops_budget=budget, - ) + yield from recipe.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance) def _minima_to_candidates( minima_records: list[MinimaRecord], - recipe: ScalingRecipe = MARIN_2025_RECIPE, + recipe: ScalingRecipe, ) -> list[CandidateConfig]: """Convert minima records to CandidateConfig objects. @@ -564,121 +470,56 @@ def _minima_to_candidates( def generate_isoflop_train_args( - sweep_config: IsoFlopSweepConfig, + budgets: Sequence[float], experiment_name: str, vocab_size: int, + recipe: ScalingRecipe, + seq_len: int = DEFAULT_SEQ_LEN, + steps_per_run: int = DEFAULT_STEPS_PER_RUN, + flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, base_optimizer_config: OptimizerConfig | None = None, ) -> list[IsoFlopTrainArgs]: """Generate training arguments for each candidate in an isoflop sweep. - This function generates all the arguments needed to call default_train() for - each candidate configuration in the sweep. The caller is responsible for - constructing the experiment-specific SimpleTrainConfig. + This is a convenience function that delegates to recipe.generate_isoflop_train_args(). + The recipe encapsulates all model-specific decisions, while this function provides + backward compatibility. Args: - sweep_config: Configuration for the sweep (budgets, seq_len, etc.) - experiment_name: Name suffix for run names (e.g., 'nemo', 'dclm') - vocab_size: Vocabulary size for the tokenizer - base_optimizer_config: Base optimizer config to modify. If None, uses CautiousConfig defaults. + budgets: Sequence of FLOP budgets to generate configs for. + experiment_name: Name suffix for run names (e.g., 'nemo', 'dclm'). + vocab_size: Vocabulary size for the tokenizer. 
+ recipe: ScalingRecipe with architecture/hyperparameter settings. + seq_len: Sequence length for training. + steps_per_run: Reference step count for FLOP budget calculation. + flop_tolerance: Tolerance for matching FLOP budget. + base_optimizer_config: Base optimizer config to modify. If None, uses recipe defaults. Returns: List of IsoFlopTrainArgs, one per candidate config across all budgets. Example: - >>> from marin.scaling_laws import IsoFlopSweepConfig, generate_isoflop_train_args - >>> from levanter.models.qwen import Qwen3Config - >>> config = IsoFlopSweepConfig(budgets=(1e18, 1e19)) - >>> train_args = generate_isoflop_train_args(config, "my-experiment", vocab_size=128256) + >>> from marin.scaling_laws import generate_isoflop_train_args, DEFAULT_BUDGETS + >>> train_args = generate_isoflop_train_args( + ... budgets=DEFAULT_BUDGETS, + ... experiment_name="my-experiment", + ... vocab_size=128256, + ... ) >>> for args in train_args: ... # Caller constructs the model config from candidate parameters ... model_config = Qwen3Config( ... hidden_dim=args.candidate.hidden_size, - ... intermediate_dim=args.candidate.intermediate_dim, - ... num_layers=args.candidate.num_layers, ... # ... etc ... ) - ... # Then use model_config with default_train() """ - recipe = sweep_config.recipe - if base_optimizer_config is None: - base_optimizer_config = CautiousConfig( - learning_rate=1.0, # Placeholder, will be overridden - weight_decay=recipe.weight_decay, - min_lr_ratio=recipe.min_lr_ratio, - warmup=recipe.warmup, - beta1=recipe.beta1, - beta2=0.98, # Placeholder, will be overridden - epsilon=recipe.epsilon, - max_grad_norm=recipe.max_grad_norm, - adamc_weight_decay=True, - lr_schedule=recipe.lr_schedule, - decay=recipe.decay, - ) - - results: list[IsoFlopTrainArgs] = [] - - for budget in sweep_config.budgets: - for candidate in candidate_configs(sweep_config, budget, vocab_size): - # Pick TPU type based on candidate parameters - tpu_type = pick_v5p_type(candidate, vocab_size, sweep_config.seq_len) - - # Build optimizer config with candidate-specific LR and beta2 - optimizer_cfg = replace( - base_optimizer_config, - learning_rate=candidate.learning_rate, - beta2=candidate.beta2, - ) - - # Generate run name and tags - run_name = ( - f"isoflop-{budget:.0e}-d{candidate.hidden_size}-" - f"L{candidate.num_layers}-B{candidate.batch_size}-{experiment_name}" - ) - - tags = ( - f"FLOPs={budget:.1e}", - f"d={candidate.hidden_size}", - f"L={candidate.num_layers}", - f"B={candidate.batch_size}", - f"steps={candidate.train_steps}", - f"tpu={tpu_type}", - ) - - # Static output path for checkpoint reuse - output_path = os.path.join("checkpoints", "isoflop", run_name) - - results.append( - IsoFlopTrainArgs( - candidate=candidate, - optimizer_config=optimizer_cfg, - tpu_type=tpu_type, - run_name=run_name, - tags=tags, - output_path=output_path, - ) - ) - - return results + return recipe.generate_isoflop_train_args( + budgets, experiment_name, vocab_size, seq_len, steps_per_run, flop_tolerance, base_optimizer_config + ) # ---------------- Helpers ---------------- -def _resolve_run_paths(runs: Sequence[ExecutorStep | InputName | str]) -> list[InputName | str]: - """Convert mixed ExecutorStep/InputName/path inputs to executor-ready paths. - - This helper reduces duplication across functions that accept either - ExecutorSteps or string paths. - - Args: - runs: Sequence of ExecutorStep, InputName, or path strings. - - Returns: - List of InputName or string paths. 
- """ - return [output_path_of(run) if isinstance(run, ExecutorStep) else run for run in runs] - - def parse_isoflop_run_name(run_name: str) -> str | None: """Parse experiment name from isoflop run name. @@ -933,11 +774,18 @@ def predict_optimal_config( scaling_fits: dict[str, ScalingFit], target_flops: float, label: str, - sweep_config: IsoFlopSweepConfig | None = None, - vocab_size: int = MARIN_TOKENIZER_VOCAB_SIZE, + vocab_size: int, + recipe: ScalingRecipe, + seq_len: int = DEFAULT_SEQ_LEN, + steps_per_run: int = DEFAULT_STEPS_PER_RUN, + flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> CandidateConfig | None: """Predict optimal training config for a target compute budget using fitted scaling laws. + This is a convenience function that delegates to recipe.predict_optimal_config(). + The recipe encapsulates all model-specific decisions, while this function provides + backward compatibility. + This implements IsoFLOP Approach 2 from the Chinchilla paper: 1. D_opt (optimal tokens) is found empirically at each compute budget by fitting parabolas to actual loss values and finding the minimum. @@ -945,60 +793,36 @@ def predict_optimal_config( 3. Given D_opt and C, N_opt (optimal params) is derived as C/(6D), so no separate alpha fit for params is needed. - This approach works regardless of whether the scaling exponents for params - and tokens are equal (alpha == beta), unlike Approach 3 which fits a - parametric loss surface. - - This function: - 1. Uses the scaling fit (N* ~ A * C^alpha) to predict optimal tokens for target_flops - 2. Generates candidate configs for the target budget using candidate_configs() - 3. Selects the candidate whose token count is closest to the predicted optimal - Args: scaling_fits: Dict of {label: ScalingFit} from scaling ladder result. target_flops: Target compute budget in FLOPs. label: Dataset/experiment label to use for scaling fit. - sweep_config: Optional IsoFlopSweepConfig. If None, uses defaults. - vocab_size: Vocabulary size (default: MARIN_TOKENIZER_VOCAB_SIZE for marin tokenizer). + vocab_size: Vocabulary size. + recipe: ScalingRecipe with architecture/hyperparameter settings. + seq_len: Sequence length for training. + steps_per_run: Reference step count for FLOP budget calculation. + flop_tolerance: Tolerance for matching FLOP budget. Returns: CandidateConfig for the predicted optimal, or None if label not in fits or no valid candidates found. 
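+
+    Example (illustrative; the scaling fits come from a prior isoflop analysis and the
+    recipe should match the one used for the sweep):
+
+        >>> best = predict_optimal_config(result.scaling_fits, target_flops=1e20,
+        ...                               label="nemo", vocab_size=128256, recipe=recipe)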
""" - if label not in scaling_fits: - logger.warning(f"Label '{label}' not found in scaling fits") - return None - - alpha, A = scaling_fits[label] - optimal_tokens = A * (target_flops**alpha) - - logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") - - if sweep_config is None: - sweep_config = IsoFlopSweepConfig() - - candidates = list(candidate_configs(sweep_config, target_flops, vocab_size)) - - if not candidates: - logger.warning(f"No valid candidates found for budget {target_flops:.2e}") - return None - - best = min(candidates, key=lambda c: abs(c.tokens - optimal_tokens)) - - logger.info( - f"Selected config: d={best.hidden_size}, L={best.num_layers}, " - f"B={best.batch_size}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})" + # Convert ScalingFit NamedTuples to plain tuples for recipe method + fits_as_tuples = {k: (v.alpha, v.A) for k, v in scaling_fits.items()} + return recipe.predict_optimal_config( + fits_as_tuples, target_flops, label, vocab_size, seq_len, steps_per_run, flop_tolerance ) - return best - def predict_optimal_configs_for_budgets( scaling_fits: dict[str, ScalingFit], target_budgets: list[float], label: str, - sweep_config: IsoFlopSweepConfig | None = None, - vocab_size: int = MARIN_TOKENIZER_VOCAB_SIZE, + vocab_size: int, + recipe: ScalingRecipe, + seq_len: int = DEFAULT_SEQ_LEN, + steps_per_run: int = DEFAULT_STEPS_PER_RUN, + flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> list[CandidateConfig]: """Predict optimal configs for multiple target compute budgets. @@ -1006,8 +830,11 @@ def predict_optimal_configs_for_budgets( scaling_fits: Dict of {label: ScalingFit} from scaling ladder result. target_budgets: List of target compute budgets in FLOPs. label: Dataset/experiment label to use for scaling fit. - sweep_config: Optional IsoFlopSweepConfig. If None, uses defaults. vocab_size: Vocabulary size. + recipe: ScalingRecipe with architecture/hyperparameter settings. + seq_len: Sequence length for training. + steps_per_run: Reference step count for FLOP budget calculation. + flop_tolerance: Tolerance for matching FLOP budget. Returns: List of CandidateConfig for each budget. 
@@ -1017,7 +844,9 @@ def predict_optimal_configs_for_budgets( """ configs = [] for budget in target_budgets: - config = predict_optimal_config(scaling_fits, budget, label, sweep_config, vocab_size) + config = predict_optimal_config( + scaling_fits, budget, label, vocab_size, recipe, seq_len, steps_per_run, flop_tolerance + ) if config is None: raise RuntimeError( f"Failed to predict optimal config for budget {budget:.2e} FLOPs " @@ -1059,13 +888,6 @@ def to_json_dict(self) -> dict: } -def _parse_fit_curve_coeffs(coeffs: Sequence[float]) -> QuadraticFitCoeffs: - if len(coeffs) != 5: - raise ValueError(f"Expected 5 fit curve coefficients, got {len(coeffs)}") - a, b, c, token_min, token_max = coeffs - return QuadraticFitCoeffs(float(a), float(b), float(c), float(token_min), float(token_max)) - - # ---------------- ExecutorStep Config ---------------- @@ -1073,44 +895,16 @@ def _parse_fit_curve_coeffs(coeffs: Sequence[float]) -> QuadraticFitCoeffs: class IsoFlopAnalysisConfig(EvalMetricsAnalysisConfig): """Configuration for scaling ladder analysis ExecutorStep.""" - recipe: ScalingRecipe = MARIN_2025_RECIPE + recipe: ScalingRecipe = field(kw_only=True) """Scaling recipe for computing optimal hyperparameters.""" - metric_key: str = DEFAULT_EVAL_METRIC_KEY + metric_key: str = field(default=DEFAULT_EVAL_METRIC_KEY, kw_only=True) """Metric to use for loss (default: eval/paloma/c4_en/bpb - Paloma benchmark on C4 English).""" - label_map: tuple[tuple[str, str], ...] | None = None + label_map: tuple[tuple[str, str], ...] | None = field(default=None, kw_only=True) """Optional mapping from experiment_name -> display label as tuple of pairs.""" -@dataclass(frozen=True) -class IsoFlopPlotsConfig: - """Configuration for isoflop plots ExecutorStep.""" - - analysis_output_path: str - """Path to the isoflop analysis output (containing isoflop_analysis_result.json).""" - - output_path: str - """Path to save the HTML plots.""" - - -@dataclass(frozen=True) -class UploadPlotsToWandbConfig: - """Configuration for uploading plots to WandB.""" - - plots_path: str - """Path to the directory containing HTML plots.""" - - wandb_entity: str = WANDB_ENTITY - """WandB entity for uploads.""" - - wandb_project: str = f"{WANDB_PROJECT}-analysis" - """WandB project for uploads.""" - - wandb_run_name: str = "scaling-ladder-analysis" - """Name for the WandB run.""" - - def run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: """Execute scaling ladder analysis (called by ExecutorStep).""" raw_df = read_metrics_dataframe(config) @@ -1167,242 +961,23 @@ def run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: logger.info(f"Saved fit curves to {fit_curves_path}") -def _run_isoflop_plots_step(config: IsoFlopPlotsConfig) -> None: - """Generate and save isoflop plots (called by ExecutorStep).""" - from marin.scaling_laws.scaling_plots import ( - create_isoflop_plot, - create_scaling_plot, - save_plots, - ) - - fs, _, _ = fsspec.get_fs_token_paths(config.analysis_output_path) - - # Load the analysis results - result_path = os.path.join(config.analysis_output_path, "isoflop_analysis_result.json") - with fs.open(result_path, "r") as f: - result_dict = json.load(f) - - # Load the dataframe - df_path = os.path.join(config.analysis_output_path, "isoflop_df.parquet") - isoflop_df = pd.read_parquet(df_path) - - # Load fit curves and reconstruct tuple keys - fit_curves_path = os.path.join(config.analysis_output_path, "fit_curves.json") - with fs.open(fit_curves_path, "r") as f: - fit_curves_json = json.load(f) - 
fit_curves: dict[tuple[str, float], QuadraticFitCoeffs] = {} - for key_str, coeffs in fit_curves_json.items(): - label, flops = key_str.rsplit("|", 1) - fit_curves[(label, float(flops))] = _parse_fit_curve_coeffs(coeffs) - - # Reconstruct minima records - minima_records = [MinimaRecord(**r) for r in result_dict["minima_records"]] - scaling_fits = {k: ScalingFit(*v) for k, v in result_dict["scaling_fits"].items()} - - # Create plots - fig_isoflop = create_isoflop_plot(isoflop_df, minima_records, fit_curves) - fig_scaling = create_scaling_plot(minima_records, scaling_fits) - - # Save plots - save_plots(fig_isoflop, fig_scaling, config.output_path) - - -def _run_upload_plots_to_wandb_step(config: UploadPlotsToWandbConfig) -> None: - """Upload plots to WandB (called by ExecutorStep).""" - from marin.scaling_laws.scaling_plots import ( - create_isoflop_plot, - create_scaling_plot, - upload_plots_to_wandb, - ) - - fs, _, _ = fsspec.get_fs_token_paths(config.plots_path) - - # Load the analysis results to regenerate plots - result_path = os.path.join(config.plots_path, "isoflop_analysis_result.json") - with fs.open(result_path, "r") as f: - result_dict = json.load(f) - - # Load the dataframe - df_path = os.path.join(config.plots_path, "isoflop_df.parquet") - isoflop_df = pd.read_parquet(df_path) - - # Load fit curves and reconstruct tuple keys - fit_curves_path = os.path.join(config.plots_path, "fit_curves.json") - with fs.open(fit_curves_path, "r") as f: - fit_curves_json = json.load(f) - fit_curves: dict[tuple[str, float], QuadraticFitCoeffs] = {} - for key_str, coeffs in fit_curves_json.items(): - label, flops = key_str.rsplit("|", 1) - fit_curves[(label, float(flops))] = _parse_fit_curve_coeffs(coeffs) - - # Reconstruct minima records - minima_records = [MinimaRecord(**r) for r in result_dict["minima_records"]] - scaling_fits = {k: ScalingFit(*v) for k, v in result_dict["scaling_fits"].items()} - - # Create plots - fig_isoflop = create_isoflop_plot(isoflop_df, minima_records, fit_curves) - fig_scaling = create_scaling_plot(minima_records, scaling_fits) - - upload_plots_to_wandb( - fig_isoflop, - fig_scaling, - entity=config.wandb_entity, - project=config.wandb_project, - run_name=config.wandb_run_name, - ) - - -# ---------------- Primary Export: ExecutorStep Factory ---------------- - - -def isoflop_analysis_step( - name: str, - training_runs: Sequence[ExecutorStep | InputName], - metric_key: str = DEFAULT_EVAL_METRIC_KEY, - label_map: dict[str, str] | None = None, - recipe: ScalingRecipe = MARIN_2025_RECIPE, -) -> ExecutorStep: - """Create an ExecutorStep for scaling ladder analysis. - - This step computes scaling law fits and saves results to JSON/parquet files. - For plotting, use `isoflop_plots_step()`. For WandB upload, use - `upload_isoflop_plots_to_wandb_step()`. - - Args: - name: Name for this executor step - training_runs: Training run ExecutorSteps or InputNames to analyze - metric_key: Which metric to use for loss (default: eval/paloma/c4_en/bpb) - label_map: Optional mapping from experiment_name -> display label - recipe: ScalingRecipe with hyperparameters - - Returns: - ExecutorStep configured to run the analysis - - Example: - >>> from marin.scaling_laws import isoflop_analysis_step, isoflop_plots_step - >>> analysis = isoflop_analysis_step( - ... name="my-scaling-analysis", - ... training_runs=my_training_steps, - ... ) - >>> plots = isoflop_plots_step( - ... name="my-scaling-plots", - ... analysis_step=analysis, - ... 
) - """ - run_paths = _resolve_run_paths(training_runs) - - config = IsoFlopAnalysisConfig( - training_runs=run_paths, - output_path=this_output_path(), - recipe=recipe, - metric_key=metric_key, - label_map=tuple(label_map.items()) if label_map else None, - ) - - return ExecutorStep( - name=name, - fn=run_isoflop_analysis_step, - config=config, - description=f"Scaling ladder analysis for {len(training_runs)} training runs", - ) - - -def isoflop_plots_step( - name: str, - analysis_step: ExecutorStep | InputName, -) -> ExecutorStep: - """Create an ExecutorStep to generate isoflop HTML plots. - - This step reads the output from an isoflop_analysis_step and generates - HTML plots for the isoflop curves and scaling fits. - - Args: - name: Name for this executor step - analysis_step: The isoflop_analysis_step to read results from - - Returns: - ExecutorStep configured to generate plots - - Example: - >>> analysis = isoflop_analysis_step(name="analysis", training_runs=runs) - >>> plots = isoflop_plots_step(name="plots", analysis_step=analysis) - """ - analysis_path = output_path_of(analysis_step) if isinstance(analysis_step, ExecutorStep) else analysis_step - - config = IsoFlopPlotsConfig( - analysis_output_path=analysis_path, - output_path=this_output_path(), - ) - - return ExecutorStep( - name=name, - fn=_run_isoflop_plots_step, - config=config, - description="Generate isoflop HTML plots", - ) - - -def upload_isoflop_plots_to_wandb_step( - name: str, - analysis_step: ExecutorStep | InputName, - wandb_entity: str = WANDB_ENTITY, - wandb_project: str = f"{WANDB_PROJECT}-analysis", - wandb_run_name: str | None = None, -) -> ExecutorStep: - """Create an ExecutorStep to upload isoflop plots to WandB. - - This step reads the analysis results and uploads interactive plots to WandB. - - Args: - name: Name for this executor step - analysis_step: The isoflop_analysis_step to read results from - wandb_entity: WandB entity for uploads - wandb_project: WandB project for uploads - wandb_run_name: Name for WandB run (defaults to step name) - - Returns: - ExecutorStep configured to upload plots to WandB - - Example: - >>> analysis = isoflop_analysis_step(name="analysis", training_runs=runs) - >>> upload = upload_isoflop_plots_to_wandb_step( - ... name="upload-plots", - ... analysis_step=analysis, - ... ) - """ - analysis_path = output_path_of(analysis_step) if isinstance(analysis_step, ExecutorStep) else analysis_step - - config = UploadPlotsToWandbConfig( - plots_path=analysis_path, - wandb_entity=wandb_entity, - wandb_project=wandb_project, - wandb_run_name=wandb_run_name or name, - ) - - return ExecutorStep( - name=name, - fn=_run_upload_plots_to_wandb_step, - config=config, - description="Upload isoflop plots to WandB", - ) - - # ---------------- Programmatic Interface ---------------- def run_isoflop_analysis( - training_runs: Sequence[ExecutorStep] | Sequence[str], + training_runs: Sequence[str], + recipe: ScalingRecipe, metric_key: str = DEFAULT_EVAL_METRIC_KEY, label_map: dict[str, str] | None = None, - recipe: ScalingRecipe = MARIN_2025_RECIPE, ) -> IsoFlopAnalysisResult: """Analyze isoflop training runs and return optimal training configurations. - This is the programmatic interface for scaling ladder analysis. For pipeline - usage, prefer `isoflop_analysis_step()` which returns an ExecutorStep. + This is the programmatic interface for scaling ladder analysis, useful for + notebooks or scripts. For ExecutorStep-based pipelines, use + `run_isoflop_analysis_step()` with `IsoFlopAnalysisConfig`. 
Args: - training_runs: List of ExecutorSteps or path strings to training runs + training_runs: List of path strings to training run output directories metric_key: Which metric to use for loss (default: eval/paloma/c4_en/bpb) label_map: Optional mapping from experiment_name -> display label recipe: ScalingRecipe with hyperparameter settings @@ -1410,10 +985,8 @@ def run_isoflop_analysis( Returns: IsoFlopAnalysisResult with configs, scaling_fits, and analysis data """ - run_paths = _resolve_run_paths(training_runs) - config = EvalMetricsAnalysisConfig( - training_runs=run_paths, + training_runs=training_runs, output_path="analysis/scaling_ladder", ) raw_df = read_metrics_dataframe(config) diff --git a/lib/marin/src/marin/scaling_laws/recipe.py b/lib/marin/src/marin/scaling_laws/recipe.py index 560bcbe685..6a290fcc8f 100644 --- a/lib/marin/src/marin/scaling_laws/recipe.py +++ b/lib/marin/src/marin/scaling_laws/recipe.py @@ -12,53 +12,82 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Scaling recipes: named hyperparameter bundles for scaling law experiments. +"""Scaling recipes: model-specific hyperparameter bundles for scaling law experiments. -A recipe makes "opinionated defaults" explicit and named, so users consciously -choose which set of hyperparameters to use rather than getting hidden defaults. - -The recipe controls: +A ScalingRecipe encapsulates ALL model-specific decisions for scaling experiments: - Architecture formula (how to compute num_layers from hidden_size) - Architecture ratios (MLP width, head size) +- Model config building (returns LlamaConfig or subclass) - Learning rate and optimizer hyperparameters - Search bounds and constraints for isoflop sweeps +- Candidate generation for isoflop sweeps + +The scaling_laws module provides model-agnostic utilities (FLOP math, scaling law +fitting), while the recipe provides the model-specific "driver" that uses them. 
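As a small numeric sketch of the kind of formula the recipe owns (mirroring
compute_learning_rate below; the constant, batch size, and width here are
illustrative values, not the recipe defaults):

    >>> import math
    >>> lr_constant, batch_size, hidden_dim = 0.25, 256, 1024
    >>> round(lr_constant * math.sqrt(batch_size) / hidden_dim, 6)
    0.003906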
+ +Experiments should define their own recipe instances: + + from marin.scaling_laws import ScalingRecipe + + MY_RECIPE = ScalingRecipe(name="my-experiment") + + # Generate candidates for a FLOP budget + for candidate in MY_RECIPE.candidate_configs(budget=1e18, vocab_size=128256): + print(candidate) -Usage: - from marin.scaling_laws.recipe import MARIN_2025_RECIPE, ScalingRecipe - - # Use the default recipe - recipe = MARIN_2025_RECIPE - model_config = recipe.build_model_config(hidden_size=1024, seq_len=4096) - lr = recipe.compute_learning_rate(batch_size=256, hidden_dim=1024) - beta2 = recipe.compute_beta2(batch_size=256) - - # Or create a custom recipe with different architecture formula - my_recipe = ScalingRecipe( - name="my-experiment", - lr_constant=0.25, - base_hidden_layer_ratio=48, # shallower models - ) + # Build model and optimizer configs + model_config = MY_RECIPE.build_model_config(hidden_size=1024) + optimizer_config = MY_RECIPE.build_optimizer_config(lr=0.001, beta2=0.95) """ import math -from dataclasses import dataclass +import os +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig +from levanter.models.llama import LlamaConfig from levanter.models.qwen import Qwen3Config from levanter.optim.cautious import CautiousConfig +from levanter.optim.config import OptimizerConfig + +if TYPE_CHECKING: + from marin.scaling_laws.isoflop_analysis import CandidateConfig, IsoFlopTrainArgs + +# TODO: LlamaConfig is used as our "abstract" model config base class. +# All model configs we use (Qwen3Config, etc.) inherit from LlamaConfig +# and provide flops_per_token() for FLOP calculations. + +# Default constants +DEFAULT_SEQ_LEN = 4096 +DEFAULT_STEPS_PER_RUN = 2**16 # Reference step count for hyperparameter tuning +DEFAULT_FLOP_TOLERANCE = 0.01 # Relative error tolerance for FLOP budget +DEFAULT_TOKENIZER = "stanford-crfm/marin-tokenizer" + + +def _round_to_power_of_two(x: float) -> int: + """Round to the nearest power of 2.""" + if x <= 0: + return 1 + log2_x = math.log2(x) + lower = 2 ** int(log2_x) + upper = 2 ** (int(log2_x) + 1) + return lower if (x - lower) < (upper - x) else upper @dataclass(frozen=True) class ScalingRecipe: """A named set of hyperparameters for scaling law experiments. - The recipe controls: + The recipe encapsulates ALL model-specific decisions: - Architecture formula (num_layers from hidden_size) - Architecture ratios (MLP width, head size) - Learning rate scaling formula - Beta2 scaling formula (for Adam) - Optimizer hyperparameters (weight decay, warmup, etc.) - Search bounds and constraints for isoflop sweeps + - Candidate generation """ name: str @@ -136,6 +165,8 @@ class ScalingRecipe: budget_step_threshold: float = 9e18 """Budget threshold for switching step sizes.""" + # --- Hyperparameter formulas --- + def compute_learning_rate(self, batch_size: int, hidden_dim: int) -> float: """Compute learning rate from batch size and hidden dim.""" return (self.lr_constant * math.sqrt(batch_size)) / hidden_dim @@ -145,10 +176,7 @@ def compute_beta2(self, batch_size: int) -> float: return self.beta2_base ** (batch_size / self.beta2_batch_divisor) def compute_num_layers(self, hidden_size: int) -> int: - """Compute number of layers from hidden size using the depth-width formula. - - This is an opinionated formula for balancing model depth and width. 
- """ + """Compute number of layers from hidden size using the depth-width formula.""" hs_pow = math.log2(hidden_size) return round( hidden_size @@ -161,19 +189,20 @@ def get_step_size(self, budget: float) -> int: return self.large_budget_step_size return self.small_budget_step_size - def build_model_config(self, hidden_size: int, seq_len: int = 4096) -> Qwen3Config: - """Build a model config from hidden_size using this recipe's architecture formula. + # --- Model config building --- - This is the key interface: the recipe makes all architecture decisions - and returns a fully-specified model config. + def build_model_config(self, hidden_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: + """Build a model config from hidden_size using this recipe's architecture formula. Args: hidden_size: Model hidden dimension. seq_len: Maximum sequence length. Returns: - A Qwen3Config with architecture determined by this recipe. + A LlamaConfig (or subclass) with architecture determined by this recipe. """ + # TODO: Currently returns Qwen3Config which inherits from LlamaConfig. + # This could be parameterized to return different model types. num_layers = self.compute_num_layers(hidden_size) intermediate_dim = hidden_size * self.mlp_ratio n_heads = max(1, hidden_size // self.hidden_head_ratio) @@ -188,18 +217,15 @@ def build_model_config(self, hidden_size: int, seq_len: int = 4096) -> Qwen3Conf rope=Llama3RotaryEmbeddingsConfig(), ) - def build_optimizer_config(self, learning_rate: float, beta2: float) -> CautiousConfig: + def build_optimizer_config(self, learning_rate: float, beta2: float) -> OptimizerConfig: """Build optimizer config using this recipe's hyperparameters. - This centralizes all optimizer configuration in the recipe, ensuring - consistent hyperparameters across isoflop sweeps and optimal training runs. - Args: - learning_rate: Learning rate (typically from CandidateConfig). - beta2: Adam beta2 (typically from CandidateConfig). + learning_rate: Learning rate. + beta2: Adam beta2. Returns: - A CautiousConfig with optimizer settings from this recipe. + An OptimizerConfig with settings from this recipe. """ return CautiousConfig( learning_rate=learning_rate, @@ -215,7 +241,235 @@ def build_optimizer_config(self, learning_rate: float, beta2: float) -> Cautious decay=self.decay, ) + # --- Candidate generation (model-specific search) --- + + def candidate_configs( + self, + budget: float, + vocab_size: int, + seq_len: int = DEFAULT_SEQ_LEN, + steps_per_run: int = DEFAULT_STEPS_PER_RUN, + flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, + ) -> "Iterator[CandidateConfig]": + """Yield candidate model configurations within the FLOP budget. + + This method encapsulates the model-specific search logic: + - Iterates over hidden_sizes using the recipe's search bounds + - Builds model configs using the recipe's architecture formula + - Applies the recipe's LR and batch size constraints + + Args: + budget: Target FLOP budget. + vocab_size: Vocabulary size for the tokenizer. + seq_len: Sequence length for training. + steps_per_run: Reference step count for FLOP budget calculation. + flop_tolerance: Tolerance for matching FLOP budget (relative error). + + Yields: + CandidateConfig objects for each valid configuration. 
+ """ + # Import here to avoid circular dependency + from marin.scaling_laws.isoflop_analysis import CandidateConfig, solve_for_batch_size, solve_for_train_steps + + step_size = self.get_step_size(budget) + min_hidden = 2**self.min_hidden_pow + max_hidden = 2**self.max_hidden_pow + + for hidden_size in range(min_hidden, max_hidden + 1, step_size): + # Build model config using recipe's architecture formula + model_config = self.build_model_config(hidden_size, seq_len) + + # Use model-agnostic FLOP utilities to solve for batch/steps + # TODO: LlamaConfig.flops_per_token() provides the FLOP calculation + batch_exact = solve_for_batch_size(model_config, vocab_size, budget, steps_per_run, seq_len) + batch_size = _round_to_power_of_two(batch_exact) + + # Apply LR constraint (recipe-specific) + lr = self.compute_learning_rate(batch_size, hidden_size) + while lr > self.max_learning_rate: + batch_size //= 2 + lr = self.compute_learning_rate(batch_size, hidden_size) + + # Apply min batch constraint (recipe-specific) + if batch_size < self.min_batch_size: + continue + + # Solve for steps to hit budget with chosen batch + train_steps = round(solve_for_train_steps(model_config, vocab_size, budget, batch_size, seq_len)) + + # Verify we hit the budget within tolerance + # Training FLOPs = 3 * flops_per_token * batch * steps * seq_len + # The 3x multiplier accounts for forward (1x) + backward (2x) pass + achieved_flops = 3 * model_config.flops_per_token(vocab_size, seq_len) * batch_size * train_steps * seq_len + if abs(achieved_flops - budget) / budget > flop_tolerance: + continue + + # Compute optimizer hyperparameters (recipe-specific) + beta2 = self.compute_beta2(batch_size) + tokens = batch_size * train_steps * seq_len + + yield CandidateConfig( + hidden_size=hidden_size, + intermediate_dim=model_config.intermediate_dim, + num_layers=model_config.num_layers, + num_heads=model_config.num_heads, + num_kv_heads=model_config.num_kv_heads, + batch_size=batch_size, + train_steps=train_steps, + learning_rate=lr, + beta2=beta2, + tokens=tokens, + flops_budget=budget, + ) + + def generate_isoflop_train_args( + self, + budgets: Sequence[float], + experiment_name: str, + vocab_size: int, + seq_len: int = DEFAULT_SEQ_LEN, + steps_per_run: int = DEFAULT_STEPS_PER_RUN, + flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, + base_optimizer_config: OptimizerConfig | None = None, + ) -> "list[IsoFlopTrainArgs]": + """Generate training arguments for each candidate in an isoflop sweep. + + This method generates all the arguments needed to set up training runs for + each candidate configuration in the sweep. + + Args: + budgets: Sequence of FLOP budgets to generate configs for. + experiment_name: Name suffix for run names (e.g., 'nemo', 'dclm'). + vocab_size: Vocabulary size for the tokenizer. + seq_len: Sequence length for training. + steps_per_run: Reference step count for FLOP budget calculation. + flop_tolerance: Tolerance for matching FLOP budget. + base_optimizer_config: Base optimizer config to modify. If None, uses recipe defaults. + + Returns: + List of IsoFlopTrainArgs, one per candidate config across all budgets. 
+ """ + # Import here to avoid circular dependency + from marin.scaling_laws.isoflop_analysis import IsoFlopTrainArgs + from marin.scaling_laws.tpu_utils import pick_v5p_type + + if base_optimizer_config is None: + base_optimizer_config = CautiousConfig( + learning_rate=1.0, # Placeholder, will be overridden + weight_decay=self.weight_decay, + min_lr_ratio=self.min_lr_ratio, + warmup=self.warmup, + beta1=self.beta1, + beta2=0.98, # Placeholder, will be overridden + epsilon=self.epsilon, + max_grad_norm=self.max_grad_norm, + adamc_weight_decay=True, + lr_schedule=self.lr_schedule, + decay=self.decay, + ) + + results: list[IsoFlopTrainArgs] = [] + + for budget in budgets: + for candidate in self.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance): + # Pick TPU type based on candidate parameters + tpu_type = pick_v5p_type(candidate, vocab_size, seq_len) + + # Build optimizer config with candidate-specific LR and beta2 + optimizer_cfg = replace( + base_optimizer_config, + learning_rate=candidate.learning_rate, + beta2=candidate.beta2, + ) + + # Generate run name and tags + run_name = ( + f"isoflop-{budget:.0e}-d{candidate.hidden_size}-" + f"L{candidate.num_layers}-B{candidate.batch_size}-{experiment_name}" + ) + + tags = ( + f"FLOPs={budget:.1e}", + f"d={candidate.hidden_size}", + f"L={candidate.num_layers}", + f"B={candidate.batch_size}", + f"steps={candidate.train_steps}", + f"tpu={tpu_type}", + ) + + # Static output path for checkpoint reuse + output_path = os.path.join("checkpoints", "isoflop", run_name) + + results.append( + IsoFlopTrainArgs( + candidate=candidate, + optimizer_config=optimizer_cfg, + tpu_type=tpu_type, + run_name=run_name, + tags=tags, + output_path=output_path, + ) + ) + + return results + + def predict_optimal_config( + self, + scaling_fits: "dict[str, tuple[float, float]]", + target_flops: float, + label: str, + vocab_size: int, + seq_len: int = DEFAULT_SEQ_LEN, + steps_per_run: int = DEFAULT_STEPS_PER_RUN, + flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, + ) -> "CandidateConfig | None": + """Predict optimal training config for a target compute budget using fitted scaling laws. + + This implements IsoFLOP Approach 2 from the Chinchilla paper: + 1. Uses the scaling fit (N* ~ A * C^alpha) to predict optimal tokens for target_flops + 2. Generates candidate configs for the target budget using this recipe + 3. Selects the candidate whose token count is closest to the predicted optimal + + Args: + scaling_fits: Dict of {label: (alpha, A)} from scaling ladder result. + target_flops: Target compute budget in FLOPs. + label: Dataset/experiment label to use for scaling fit. + vocab_size: Vocabulary size. + seq_len: Sequence length for training. + steps_per_run: Reference step count for FLOP budget calculation. + flop_tolerance: Tolerance for matching FLOP budget. + + Returns: + CandidateConfig for the predicted optimal, or None if label not in fits + or no valid candidates found. 
+ """ + import logging + + logger = logging.getLogger(__name__) + + if label not in scaling_fits: + logger.warning(f"Label '{label}' not found in scaling fits") + return None + + alpha, A = scaling_fits[label] + optimal_tokens = A * (target_flops**alpha) + + logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") + + candidates = list(self.candidate_configs(target_flops, vocab_size, seq_len, steps_per_run, flop_tolerance)) + + if not candidates: + logger.warning(f"No valid candidates found for budget {target_flops:.2e}") + return None + + best = min(candidates, key=lambda c: c.tokens - optimal_tokens if c.tokens >= optimal_tokens else float("inf")) + # If all candidates have fewer tokens than optimal, pick the one with the most tokens + if best.tokens < optimal_tokens: + best = max(candidates, key=lambda c: c.tokens) + + logger.info( + f"Selected config: d={best.hidden_size}, L={best.num_layers}, " + f"B={best.batch_size}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})" + ) -# Named recipes -MARIN_2025_RECIPE = ScalingRecipe(name="marin-2025") -"""Default Marin scaling recipe based on 2025 best practices.""" + return best diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 3457e6499d..28f5d2c889 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -14,32 +14,36 @@ """Scaling ladder: compute-optimal training runs based on IsoFLOP analysis. -This module provides ExecutorSteps for training models with compute-optimal -configurations derived from IsoFLOP analysis. +This module provides functions and configs for training models with compute-optimal +configurations derived from IsoFLOP analysis. Experiments create ExecutorSteps +directly using the provided functions. 
-Usage: - from marin.scaling_laws import isoflop_analysis_step, scaling_ladder_rung_step +Example usage in experiments: - # First, run IsoFLOP analysis - analysis = isoflop_analysis_step( - name="scaling-analysis", - training_runs=isoflop_training_steps, + from marin.execution.executor import ExecutorStep, output_path_of + from marin.scaling_laws import ( + ScalingLadderRungConfig, + run_scaling_ladder_rung, ) - # Then create optimal training steps (ladder rungs) that depend on the analysis - rung_1e21 = scaling_ladder_rung_step( + # Create optimal training step that depends on analysis output + optimal_step = ExecutorStep( name="optimal-1e21", - analysis_step=analysis, - target_budget=1e21, - label="nemo", - tokenized=my_tokenized_dataset, + fn=run_scaling_ladder_rung, + config=ScalingLadderRungConfig( + analysis_output_path=output_path_of(analysis_step), + target_budget=1e21, + label="nemo", + tokenized=my_tokenized_dataset, + output_path="checkpoints/optimal-1e21", + ), ) """ import json import logging import os -from collections.abc import Callable, Sequence +from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta @@ -57,26 +61,19 @@ from levanter.trainer import TrainerConfig from levanter.utils.mesh import MeshConfig -from marin.execution.executor import ExecutorStep, InputName, output_path_of, this_output_path from marin.processing.tokenize import get_vocab_size_for_tokenizer from marin.processing.tokenize.data_configs import add_validation_sets_to_mixture, lm_data_config -from marin.processing.tokenize.tokenize import TokenizeConfig from marin.scaling_laws.isoflop_analysis import ( CandidateConfig, - IsoFlopSweepConfig, ScalingFit, - isoflop_analysis_step, - pick_v5p_type, predict_optimal_config, ) -from marin.scaling_laws.recipe import MARIN_2025_RECIPE, ScalingRecipe +from marin.scaling_laws.tpu_utils import pick_v5p_type +from marin.scaling_laws.recipe import ScalingRecipe from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm logger = logging.getLogger(__name__) -# Type alias for tokenizer steps -TokenizerStep = ExecutorStep[TokenizeConfig] - # Type alias for model builder callbacks # Takes (candidate, seq_len) and returns a model config ModelBuilder = Callable[[CandidateConfig, int], LmConfig] @@ -100,8 +97,8 @@ def default_model_builder(candidate: CandidateConfig, seq_len: int) -> Qwen3Conf def _prepare_data_config( - tokenized: InputName | str | LMMixtureDatasetConfig, - validation_sets: dict[str, TokenizerStep] | None = None, + tokenized: str | LMMixtureDatasetConfig, + validation_sets: dict | None = None, ) -> LMMixtureDatasetConfig: """Prepare a tokenized dataset for training. @@ -110,8 +107,8 @@ def _prepare_data_config( explicitly if needed. Args: - tokenized: The tokenized dataset - can be an InputName, path string, - or an already-configured LMMixtureDatasetConfig. + tokenized: The tokenized dataset - can be a path string or an + already-configured LMMixtureDatasetConfig. validation_sets: Optional dict of validation sets to add. If None, no validation sets are added. 
@@ -123,7 +120,7 @@ def _prepare_data_config( if validation_sets: pretraining_data = add_validation_sets_to_mixture(pretraining_data, validation_sets) else: - # InputName or string path + # String path pretraining_data = lm_data_config( training_set=tokenized, validation_sets=validation_sets, @@ -136,7 +133,7 @@ def _prepare_data_config( class ScalingLadderRungConfig: """Configuration for one rung of the scaling ladder (one compute-optimal training run). - This config references an IsoFLOP analysis step and specifies + This config references an IsoFLOP analysis output and specifies the target compute budget. At runtime, the optimal config is loaded from the analysis output. """ @@ -150,28 +147,25 @@ class ScalingLadderRungConfig: label: str """Dataset label to use for scaling fit (e.g., 'nemo', 'comma', 'dclm').""" - tokenized: InputName | str | LMMixtureDatasetConfig - """Tokenized dataset for training. Can be a path, InputName, or LMMixtureDatasetConfig.""" + tokenized: str | LMMixtureDatasetConfig + """Tokenized dataset for training. Can be a path string or LMMixtureDatasetConfig.""" output_path: str """Where to write training outputs.""" + recipe: ScalingRecipe + """Scaling recipe with hyperparameters.""" + model_builder: ModelBuilder | None = None """Function to build model config from CandidateConfig. If None, uses default_model_builder (Qwen3).""" - recipe: ScalingRecipe = MARIN_2025_RECIPE - """Scaling recipe with hyperparameters.""" - tokenizer: str = "stanford-crfm/marin-tokenizer" """Tokenizer to use.""" seq_len: int = 4096 """Sequence length for training.""" - sweep_config: IsoFlopSweepConfig | None = None - """Optional sweep config for predict_optimal_config. Uses defaults if None.""" - - validation_sets: dict[str, TokenizerStep] | None = None + validation_sets: dict | None = None """Optional validation sets to add for eval loss tracking.""" @@ -195,8 +189,9 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: scaling_fits=scaling_fits, target_flops=config.target_budget, label=config.label, - sweep_config=config.sweep_config, vocab_size=vocab_size, + recipe=config.recipe, + seq_len=config.seq_len, ) if candidate is None: @@ -262,179 +257,3 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: ) run_levanter_train_lm(full_config) - - -def scaling_ladder_rung_step( - name: str, - analysis_step: ExecutorStep, - target_budget: float, - label: str, - tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, - model_builder: ModelBuilder | None = None, - recipe: ScalingRecipe = MARIN_2025_RECIPE, - tokenizer: str = "stanford-crfm/marin-tokenizer", - seq_len: int = 4096, - override_output_path: str | None = None, - validation_sets: dict[str, TokenizerStep] | None = None, -) -> ExecutorStep: - """Create an ExecutorStep for one rung of the scaling ladder. - - This step depends on an IsoFLOP analysis step and will train a model - using the optimal configuration predicted from the scaling fits. - - Args: - name: Name for this executor step - analysis_step: The IsoFLOP analysis step to read fits from - target_budget: Target compute budget in FLOPs - label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') - tokenized: Tokenized dataset to train on. Can be an ExecutorStep, InputName, - or LMMixtureDatasetConfig. - model_builder: Function to build model config from CandidateConfig. - If None, uses default_model_builder (Qwen3Config). 
- recipe: ScalingRecipe with hyperparameters - tokenizer: Tokenizer to use - seq_len: Sequence length for training - override_output_path: Optional override for the output path - validation_sets: Optional validation sets for eval loss tracking - - Returns: - ExecutorStep configured to run one optimal training run - """ - if isinstance(tokenized, ExecutorStep): - resolved_tokenized: InputName | str | LMMixtureDatasetConfig = output_path_of(tokenized) - elif isinstance(tokenized, LMMixtureDatasetConfig): - resolved_tokenized = tokenized - else: - resolved_tokenized = tokenized - - output_path = override_output_path if override_output_path is not None else this_output_path() - - config = ScalingLadderRungConfig( - analysis_output_path=output_path_of(analysis_step), - target_budget=target_budget, - label=label, - tokenized=resolved_tokenized, - output_path=output_path, - model_builder=model_builder, - recipe=recipe, - tokenizer=tokenizer, - seq_len=seq_len, - validation_sets=validation_sets, - ) - - step = ExecutorStep( - name=os.path.join("checkpoints", name), - fn=run_scaling_ladder_rung, - config=config, - description=f"Scaling ladder rung: optimal training for {target_budget:.1e} FLOPs based on IsoFLOP analysis", - ) - - if override_output_path is not None: - step = step.with_output_path(override_output_path) - - return step - - -# ---------------- Scaling Ladder Suite ---------------- - - -@dataclass -class ScalingLadderSuite: - """A suite containing IsoFLOP analysis and scaling ladder rungs (optimal training steps). - - This is returned by `scaling_ladder_suite()` and contains all the steps - needed for end-to-end scaling ladder: IsoFLOP analysis + optimal training runs. - """ - - analysis: ExecutorStep - """The IsoFLOP analysis step.""" - - optimal_runs: list[ExecutorStep] - """Scaling ladder rungs: training steps for each target budget, using predicted optimal configs.""" - - @property - def all_steps(self) -> list[ExecutorStep]: - """All steps in the suite (analysis + optimal runs).""" - return [self.analysis, *self.optimal_runs] - - -def scaling_ladder_suite( - name: str, - training_runs: Sequence[ExecutorStep | InputName], - target_budgets: Sequence[float], - label: str, - tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig, - model_builder: ModelBuilder | None = None, - recipe: ScalingRecipe = MARIN_2025_RECIPE, - tokenizer: str = "stanford-crfm/marin-tokenizer", - seq_len: int = 4096, - metric_key: str = "eval/paloma/c4_en/bpb", - label_map: dict[str, str] | None = None, - validation_sets: dict[str, TokenizerStep] | None = None, -) -> ScalingLadderSuite: - """Create a complete scaling ladder: IsoFLOP analysis + optimal training runs. - - This is the full pipeline interface that creates: - 1. An IsoFLOP analysis step that fits scaling laws - 2. Scaling ladder rungs (optimal training steps) for each target budget - - The optimal training steps depend on the analysis step and will train - models using compute-optimal configurations predicted from the scaling fits. - - Args: - name: Base name for the steps - training_runs: IsoFLOP training run ExecutorSteps to analyze - target_budgets: Target compute budgets (in FLOPs) for optimal training - label: Dataset label to use for scaling fit (e.g., 'nemo', 'comma') - tokenized: Tokenized dataset for optimal training runs. Can be an ExecutorStep, - InputName, or LMMixtureDatasetConfig. - model_builder: Function to build model config from CandidateConfig. - If None, uses default_model_builder (Qwen3Config). 
- recipe: ScalingRecipe with hyperparameters - tokenizer: Tokenizer to use - seq_len: Sequence length for training - metric_key: Which metric to use for loss - label_map: Optional mapping from experiment_name -> display label - validation_sets: Optional validation sets for eval loss tracking - - Returns: - ScalingLadderSuite containing the analysis step and optimal training steps - - Example: - >>> suite = scaling_ladder_suite( - ... name="nemo-scaling", - ... training_runs=isoflop_training_steps, - ... target_budgets=[1e21, 3e21, 1e22], - ... label="nemo", - ... tokenized=nemotron_tokenized, - ... ) - >>> all_steps = [*isoflop_training_steps, *suite.all_steps] - """ - analysis = isoflop_analysis_step( - name=f"{name}-analysis", - training_runs=training_runs, - metric_key=metric_key, - label_map=label_map, - recipe=recipe, - ) - - optimal_runs = [] - for budget in target_budgets: - run_step = scaling_ladder_rung_step( - name=f"{name}-optimal-{budget:.2e}", - analysis_step=analysis, - target_budget=budget, - label=label, - tokenized=tokenized, - model_builder=model_builder, - recipe=recipe, - tokenizer=tokenizer, - seq_len=seq_len, - validation_sets=validation_sets, - ) - optimal_runs.append(run_step) - - return ScalingLadderSuite( - analysis=analysis, - optimal_runs=optimal_runs, - ) From 13fdd1fe7c1d5ab69b88bd5c45e067c688c0a5ee Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 7 Jan 2026 16:46:55 -0800 Subject: [PATCH 49/79] Lint --- experiments/isoflop_sweep.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 65b82b67dc..11691f5ae8 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -50,7 +50,6 @@ ) from marin.scaling_laws import ScalingRecipe - # --- Scaling Recipe --- # This recipe encapsulates all model-specific hyperparameters for Marin scaling experiments. # Other experiments can define their own recipes by instantiating ScalingRecipe with different values. 
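For reference, other experiments can pin their own recipe next to their sweep
definition; a hedged sketch (the overridden fields and values are hypothetical,
chosen only to show the pattern):

    from marin.scaling_laws import ScalingRecipe

    MY_RECIPE = ScalingRecipe(
        name="my-2026-sweep",
        lr_constant=0.25,             # assumed override of the LR formula constant
        base_hidden_layer_ratio=48,   # assumed override of the depth-width formula
    )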
From 5343625f314273f998a087747a8f2d935557dde2 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 7 Jan 2026 20:43:25 -0800 Subject: [PATCH 50/79] Missing Tabs --- experiments/exp1603_subgroup_evals.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/experiments/exp1603_subgroup_evals.py b/experiments/exp1603_subgroup_evals.py index 872d397fe0..3fd6ef2c92 100644 --- a/experiments/exp1603_subgroup_evals.py +++ b/experiments/exp1603_subgroup_evals.py @@ -108,16 +108,16 @@ def create_eval_steps() -> list: ) steps.append(step) - logprobs_step = default_lm_log_probs( - output_path_of(model).cd("checkpoints"), - build_model_config(candidate), - dist_eval, - resource_config=ResourceConfig.with_tpu("v5p-8"), - checkpoint_is_hf=False, - name=versioned(f"{name}-DistRobust-ICE-logprobs"), - ) + logprobs_step = default_lm_log_probs( + output_path_of(model).cd("checkpoints"), + build_model_config(candidate), + dist_eval, + resource_config=ResourceConfig.with_tpu("v5p-8"), + checkpoint_is_hf=False, + name=versioned(f"{name}-DistRobust-ICE-logprobs"), + ) - steps.append(logprobs_step) + steps.append(logprobs_step) baselines = [ ("allenai/OLMo-2-1124-7B", "stage2-ingredient3-step8000-tokens34B"), From 3258061b65b0812671198590cf2542f38cb0c0bf Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 7 Jan 2026 21:44:22 -0800 Subject: [PATCH 51/79] Keep Moving Stuff into the Recipe in Experiments --- tests/test_scaling_laws.py | 90 +++++++++++--------------------------- 1 file changed, 26 insertions(+), 64 deletions(-) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 7f9324c7d7..68402e0296 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -25,10 +25,10 @@ from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig from levanter.models.qwen import Qwen3Config -from marin.scaling_laws import default_model_builder +from marin.scaling_laws import ScalingRecipe from marin.scaling_laws.isoflop_analysis import ( + DEFAULT_SEQ_LEN, MARIN_TOKENIZER_VOCAB_SIZE, - IsoFlopSweepConfig, IsoFlopTrainArgs, candidate_configs, compute_training_flops, @@ -216,20 +216,22 @@ def test_parse_isoflop_run_name(): def test_candidate_configs_within_tolerance(): """Test that generated configs achieve the target FLOP budget within tolerance.""" - cfg = IsoFlopSweepConfig(flop_tolerance=0.01) + recipe = ScalingRecipe(name="test") budget = 1e19 - for candidate in candidate_configs(cfg, budget, MARIN_TOKENIZER_VOCAB_SIZE): - # Build model config from candidate to verify FLOPs - model_config = default_model_builder(candidate, cfg.seq_len) + flop_tolerance = 0.01 + seq_len = DEFAULT_SEQ_LEN + for candidate in candidate_configs(budget, MARIN_TOKENIZER_VOCAB_SIZE, recipe, flop_tolerance=flop_tolerance): + # Build model config from candidate using recipe + model_config = recipe.build_model_config(candidate.target_params, MARIN_TOKENIZER_VOCAB_SIZE, seq_len) achieved = compute_training_flops( model_config, MARIN_TOKENIZER_VOCAB_SIZE, candidate.batch_size, candidate.train_steps, - cfg.seq_len, + seq_len, ) relative_error = abs(achieved - budget) / budget - assert relative_error <= cfg.flop_tolerance + assert relative_error <= flop_tolerance # --- Curve fitting tests --- @@ -254,71 +256,34 @@ def test_robust_quad_logx_fits_quadratic(): # Snapshot of expected output for generate_isoflop_train_args with budget=3e18 training FLOPs. 
# Note: compute_training_flops includes the 3x multiplier for training (forward + backward pass), # matching how FLOPs are tracked in WandB via Levanter's log_performance_stats. +# +# CandidateConfig is now model-agnostic, containing only: +# - batch_size, train_steps, tokens, target_params, flops_budget EXPECTED_ISOFLOP_CONFIGS_3E18 = [ { - "hidden_size": 512, - "intermediate_dim": 2048, - "num_layers": 6, - "num_heads": 4, - "num_kv_heads": 4, "batch_size": 32, "train_steps": 32844, - "learning_rate": 0.003646, - "beta2": 0.994962, - "tpu_type": "v5p-8", - "run_name": "isoflop-3e+18-d512-L6-B32-test-snapshot", + "flops_budget": 3e18, }, { - "hidden_size": 640, - "intermediate_dim": 2560, - "num_layers": 7, - "num_heads": 5, - "num_kv_heads": 5, "batch_size": 16, "train_steps": 46274, - "learning_rate": 0.002063, - "beta2": 0.997478, - "tpu_type": "v5p-8", - "run_name": "isoflop-3e+18-d640-L7-B16-test-snapshot", + "flops_budget": 3e18, }, { - "hidden_size": 768, - "intermediate_dim": 3072, - "num_layers": 8, - "num_heads": 6, - "num_kv_heads": 6, "batch_size": 16, "train_steps": 33965, - "learning_rate": 0.001719, - "beta2": 0.997478, - "tpu_type": "v5p-8", - "run_name": "isoflop-3e+18-d768-L8-B16-test-snapshot", + "flops_budget": 3e18, }, { - "hidden_size": 896, - "intermediate_dim": 3584, - "num_layers": 10, - "num_heads": 7, - "num_kv_heads": 7, "batch_size": 8, "train_steps": 48105, - "learning_rate": 0.001042, - "beta2": 0.998738, - "tpu_type": "v5p-8", - "run_name": "isoflop-3e+18-d896-L10-B8-test-snapshot", + "flops_budget": 3e18, }, { - "hidden_size": 1024, - "intermediate_dim": 4096, - "num_layers": 11, - "num_heads": 8, - "num_kv_heads": 8, "batch_size": 8, "train_steps": 37335, - "learning_rate": 0.000912, - "beta2": 0.998738, - "tpu_type": "v5p-8", - "run_name": "isoflop-3e+18-d1024-L11-B8-test-snapshot", + "flops_budget": 3e18, }, ] @@ -329,12 +294,17 @@ def test_generate_isoflop_train_args_snapshot(): This test ensures the scaling_laws module produces identical configurations for reproducible isoflop sweeps. Uses 3e18 training FLOPs budget (which accounts for the 3x multiplier for forward + backward pass). + + CandidateConfig is now model-agnostic, so we only check the core compute + allocation parameters (batch_size, train_steps, flops_budget). 
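    A quick sanity check on the first snapshot entry above (seq_len is 4096, so
    that run trains on 32 * 32844 * 4096 tokens):

    >>> 32 * 32844 * 4096
    4304928768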
""" - config = IsoFlopSweepConfig(budgets=(3e18,)) + recipe = ScalingRecipe(name="test-snapshot") + budgets = (3e18,) result = generate_isoflop_train_args( - sweep_config=config, + budgets=budgets, experiment_name="test-snapshot", vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, + recipe=recipe, ) assert len(result) == len( @@ -345,17 +315,9 @@ def test_generate_isoflop_train_args_snapshot(): assert isinstance(args, IsoFlopTrainArgs) c = args.candidate actual = { - "hidden_size": c.hidden_size, - "intermediate_dim": c.intermediate_dim, - "num_layers": c.num_layers, - "num_heads": c.num_heads, - "num_kv_heads": c.num_kv_heads, "batch_size": c.batch_size, "train_steps": c.train_steps, - "learning_rate": round(c.learning_rate, 6), - "beta2": round(c.beta2, 6), - "tpu_type": args.tpu_type, - "run_name": args.run_name, + "flops_budget": c.flops_budget, } for key in expected: From fa453ae2025dd1d361656678399e091e4bac267f Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 7 Jan 2026 22:12:05 -0800 Subject: [PATCH 52/79] Name Consistently with Chinchilla --- experiments/exp1603_subgroup_evals.py | 22 +- .../exp2166_scaling_ladder_analysis.py | 6 - experiments/isoflop_sweep.py | 63 +--- lib/marin/src/marin/scaling_laws/__init__.py | 4 - .../marin/scaling_laws/isoflop_analysis.py | 172 ++++----- lib/marin/src/marin/scaling_laws/recipe.py | 216 ++++++----- .../src/marin/scaling_laws/scaling_ladder.py | 79 +--- lib/marin/src/marin/scaling_laws/tpu_utils.py | 32 +- tests/test_scaling_laws.py | 357 ++++-------------- 9 files changed, 308 insertions(+), 643 deletions(-) diff --git a/experiments/exp1603_subgroup_evals.py b/experiments/exp1603_subgroup_evals.py index 3fd6ef2c92..610c9d9e36 100644 --- a/experiments/exp1603_subgroup_evals.py +++ b/experiments/exp1603_subgroup_evals.py @@ -20,11 +20,14 @@ from experiments.llama import llama3_tokenizer from experiments.exp1342_gemstones_scaling_law import distributional_eval_sets -from experiments.isoflop_sweep import MARIN_SCALING_SUITES +from experiments.isoflop_sweep import MARIN_2025_RECIPE, MARIN_SCALING_SUITES from experiments.models import ModelConfig, download_model_step from marin.execution.executor import executor_main, output_path_of, versioned from marin.evaluation.log_probs import default_lm_log_probs -from marin.scaling_laws.isoflop_analysis import build_model_config +from marin.processing.tokenize import get_vocab_size_for_tokenizer + +# Vocab size for building model configs +VOCAB_SIZE = get_vocab_size_for_tokenizer("stanford-crfm/marin-tokenizer") # This is painfully slow to run in dry run mode # nodryrun @@ -45,7 +48,7 @@ def create_eval_steps() -> list: total_tokens = candidate.batch_size * candidate.train_steps * 4096 name = ( f"marin-nemo-{candidate.flops_budget:.0e}C-{total_tokens}T-" - f"{candidate.hidden_size}W-{candidate.num_layers}D" + f"N{candidate.target_params:.0e}" ) step = evaluate_levanter_lm_evaluation_harness( @@ -56,9 +59,10 @@ def create_eval_steps() -> list: ) steps.append(step) + model_config = MARIN_2025_RECIPE.build_model_config(candidate.target_params, VOCAB_SIZE) logprobs_step = default_lm_log_probs( output_path_of(model).cd("checkpoints"), - build_model_config(candidate), + model_config, dist_eval, resource_config=ResourceConfig.with_tpu("v5p-8"), checkpoint_is_hf=False, @@ -71,7 +75,7 @@ def create_eval_steps() -> list: total_tokens = candidate.batch_size * candidate.train_steps * 4096 name = ( f"marin-comma-{candidate.flops_budget:.0e}C-{total_tokens}T-" - f"{candidate.hidden_size}W-{candidate.num_layers}D" + 
f"N{candidate.target_params:.0e}" ) step = evaluate_levanter_lm_evaluation_harness( @@ -82,9 +86,10 @@ def create_eval_steps() -> list: ) steps.append(step) + model_config = MARIN_2025_RECIPE.build_model_config(candidate.target_params, VOCAB_SIZE) logprobs_step = default_lm_log_probs( output_path_of(model).cd("checkpoints"), - build_model_config(candidate), + model_config, dist_eval, resource_config=ResourceConfig.with_tpu("v5p-8"), checkpoint_is_hf=False, @@ -97,7 +102,7 @@ def create_eval_steps() -> list: total_tokens = candidate.batch_size * candidate.train_steps * 4096 name = ( f"marin-dclm-{candidate.flops_budget:.0e}C-{total_tokens}T-" - f"{candidate.hidden_size}W-{candidate.num_layers}D" + f"N{candidate.target_params:.0e}" ) step = evaluate_levanter_lm_evaluation_harness( @@ -108,9 +113,10 @@ def create_eval_steps() -> list: ) steps.append(step) + model_config = MARIN_2025_RECIPE.build_model_config(candidate.target_params, VOCAB_SIZE) logprobs_step = default_lm_log_probs( output_path_of(model).cd("checkpoints"), - build_model_config(candidate), + model_config, dist_eval, resource_config=ResourceConfig.with_tpu("v5p-8"), checkpoint_is_hf=False, diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index e0f784e9da..98d4fcbfaa 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -21,12 +21,6 @@ 1. Fits scaling laws from IsoFLOP sweep data to find compute-optimal configurations 2. Generates visualization plots (isoflop curves and scaling fit plots) 3. Optionally trains compute-optimal models at larger target budgets - -The analysis steps depend on completed isoflop training runs from isoflop_sweep.py. -Once complete, results are saved to the output path and uploaded to WandB. - -This experiment creates ExecutorSteps directly rather than using library factory -functions, following the pattern of isolating executor step creation to experiments. """ from experiments.defaults import default_validation_sets diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 11691f5ae8..8e621d31d1 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -15,22 +15,12 @@ """Generate ISOFlop sweep steps for varying model sizes on a target dataset. This script constructs `ExecutorStep` objects that train models of different -sizes while keeping the total training FLOPs roughly constant. It is intended -as a lightweight scaffold for ISOFlop scaling law experiments. - -ExecutorSteps are created directly in this experiment file, following the pattern -of isolating executor step creation to experiments. The library provides: -- `generate_isoflop_train_args()`: Computes model/optimizer configs for each sweep point -- `ScalingRecipe`: Named hyperparameter bundle with architecture and optimizer settings - -This file uses those to create the actual ExecutorSteps. +sizes while keeping the total training FLOPs roughly constant. 
""" from dataclasses import replace from levanter.data.text import LMMixtureDatasetConfig -from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig -from levanter.models.qwen import Qwen3Config from experiments.evals.evals import default_eval from experiments.evals.task_configs import EvalTaskConfig @@ -46,34 +36,13 @@ from marin.scaling_laws import ( DEFAULT_BUDGETS, CandidateConfig, + ScalingRecipe, generate_isoflop_train_args, + pick_v5p_type, ) -from marin.scaling_laws import ScalingRecipe - -# --- Scaling Recipe --- -# This recipe encapsulates all model-specific hyperparameters for Marin scaling experiments. -# Other experiments can define their own recipes by instantiating ScalingRecipe with different values. MARIN_2025_RECIPE = ScalingRecipe(name="marin-2025") -"""Default Marin scaling recipe based on 2025 best practices.""" - - -def build_qwen3_from_candidate(candidate: CandidateConfig, seq_len: int = 4096) -> Qwen3Config: - """Build a Qwen3Config from a CandidateConfig. - - This is the experiment-level helper for constructing model configs. - Different experiments can use different model types (LlamaConfig, etc.) - by implementing their own builder function. - """ - return Qwen3Config( - hidden_dim=candidate.hidden_size, - intermediate_dim=candidate.intermediate_dim, - num_layers=candidate.num_layers, - num_heads=candidate.num_heads, - num_kv_heads=candidate.num_kv_heads, - max_seq_len=seq_len, - rope=Llama3RotaryEmbeddingsConfig(), - ) +"""Default Marin scaling recipe.""" def create_isoflop_sweep_steps( @@ -131,16 +100,20 @@ def create_isoflop_sweep_steps( # Create ExecutorSteps for each candidate configuration for args in train_args_list: - # Build model config from candidate (experiment controls model type) - model_config = build_qwen3_from_candidate(args.candidate, seq_len) + candidate = args.candidate + + # Build model and optimizer configs using the recipe + model_config = recipe.build_model_config(candidate.target_params, vocab_size, seq_len) + optimizer_config = recipe.build_optimizer_config(candidate, vocab_size) + tpu_type = pick_v5p_type(candidate, vocab_size, seq_len, recipe) train_cfg = replace( base_train_config, - train_batch_size=args.candidate.batch_size, - learning_rate=args.candidate.learning_rate, - num_train_steps=args.candidate.train_steps, - resources=ResourceConfig.with_tpu(args.tpu_type), - optimizer_config=args.optimizer_config, + train_batch_size=candidate.batch_size, + learning_rate=optimizer_config.learning_rate, + num_train_steps=candidate.train_steps, + resources=ResourceConfig.with_tpu(tpu_type), + optimizer_config=optimizer_config, ) # Create training step @@ -156,7 +129,7 @@ def create_isoflop_sweep_steps( # Pin to static output path for checkpoint reuse train_step = train_step.with_output_path(args.output_path) train_steps.append(train_step) - candidates.append(args.candidate) + candidates.append(candidate) # Create evaluation step if eval tasks specified if eval_tasks: @@ -198,10 +171,6 @@ def create_isoflop_sweep_steps( ) -# --- Scaling Suites --- -# Each suite explicitly specifies the recipe for visibility. -# ExecutorSteps are created by create_isoflop_sweep_steps() in this file. 
- MARIN_SCALING_SUITES = { "nemotron": create_isoflop_sweep_steps( tokenized=nemotron_mix, diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index c58ecd2f7e..eb6222ae34 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -46,9 +46,7 @@ ScalingRecipe, ) from marin.scaling_laws.scaling_ladder import ( - ModelBuilder, ScalingLadderRungConfig, - default_model_builder, run_scaling_ladder_rung, ) from marin.scaling_laws.scaling_plots import ( @@ -72,7 +70,6 @@ "IsoFlopAnalysisResult", "IsoFlopTrainArgs", "MinimaRecord", - "ModelBuilder", "QuadraticFitCoeffs", "ScalingFit", "ScalingLadderRungConfig", @@ -83,7 +80,6 @@ "compute_transformer_params", "create_isoflop_plot", "create_scaling_plot", - "default_model_builder", "estimate_memory_bytes", "fit_scaling_laws", "generate_isoflop_train_args", diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index c91f556cc8..260c16a516 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -14,28 +14,7 @@ """IsoFLOP analysis for finding compute-optimal training configurations. -This module provides functions and configs for IsoFLOP scaling law analysis. -Experiments create ExecutorSteps directly using the provided functions. - -Example usage in experiments: - - from marin.execution.executor import ExecutorStep, output_path_of - from marin.scaling_laws import ( - IsoFlopAnalysisConfig, - run_isoflop_analysis_step, - ) - - # Create analysis step - analysis_step = ExecutorStep( - name="my-scaling-analysis", - fn=run_isoflop_analysis_step, - config=IsoFlopAnalysisConfig( - training_runs=[output_path_of(r) for r in training_runs], - output_path="analysis/my-analysis", - ), - ) - -The analysis step will: +This module provides functions and configs for IsoFLOP scaling law analysis: 1. Read eval metrics from completed training runs 2. Fit scaling laws to find compute-optimal token counts 3. Save results to JSON/parquet files @@ -58,7 +37,6 @@ from jaxopt import ScipyMinimize from levanter.models.llama import LlamaConfig -from levanter.optim.config import OptimizerConfig from marin.scaling_laws.eval_metrics_reader import ( EvalMetricsAnalysisConfig, @@ -101,7 +79,7 @@ class ScalingFit(NamedTuple): - """Scaling law fit parameters for N* ~ A * C^alpha.""" + """Scaling law fit parameters for D* ~ A * C^alpha (optimal tokens ~ compute^alpha).""" alpha: float """Exponent in scaling law.""" @@ -140,56 +118,41 @@ class QuadraticFitCoeffs(NamedTuple): @dataclass class CandidateConfig: - """A candidate model/training configuration from the isoflop sweep. + """Model-agnostic compute allocation from scaling law analysis. + + Contains only the fundamental parameters that scaling laws reason about: + - How much compute (flops_budget) + - How to allocate it between model size (target_params) and data (tokens) + - Training batch configuration (batch_size, train_steps) - This dataclass contains all the information needed to create a training run. - Callers are responsible for converting this to their specific config format - (e.g., SimpleTrainConfig, Qwen3Config). + All model-specific details (architecture, optimizer hyperparameters) are + computed by the ScalingRecipe from these values. 
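    A small numeric sketch of how the fields relate (the numbers are illustrative;
    the 6*N*D rule of thumb mirrors the C / (6 * D) fallback used in fit_scaling_laws):

    >>> batch_size, train_steps, seq_len = 64, 20_000, 4096
    >>> tokens = batch_size * train_steps * seq_len
    >>> tokens
    5242880000
    >>> round(3.0e18 / (6 * tokens) / 1e6)     # implied target_params, in millions
    95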
""" - hidden_size: int - intermediate_dim: int - num_layers: int - num_heads: int - num_kv_heads: int batch_size: int train_steps: int - learning_rate: float - beta2: float - tokens: float # total tokens = batch_size * train_steps * seq_len - flops_budget: float = 0.0 # the FLOP budget this config was generated for + tokens: float # = batch_size * train_steps * seq_len + target_params: int # Optimal parameter count for this flops_budget + flops_budget: float # Compute budget this config was generated for @dataclass class IsoFlopTrainArgs: """Arguments needed to set up an isoflop training run. - This dataclass contains the parameters needed for training. The caller is - responsible for constructing the model config from candidate parameters, - allowing flexibility in model type (Qwen3Config, LlamaConfig, etc.). + This dataclass contains the model-agnostic parameters needed for training. + The ScalingRecipe is responsible for converting these to model-specific + configs (model architecture, optimizer hyperparameters). Example: - >>> args = generate_isoflop_train_args(config, "my-exp", vocab_size)[0] - >>> # Caller constructs the model config - >>> model_config = Qwen3Config( - ... hidden_dim=args.candidate.hidden_size, - ... intermediate_dim=args.candidate.intermediate_dim, - ... num_layers=args.candidate.num_layers, - ... num_heads=args.candidate.num_heads, - ... num_kv_heads=args.candidate.num_kv_heads, - ... max_seq_len=4096, - ... rope=Llama3RotaryEmbeddingsConfig(), - ... ) + >>> args = generate_isoflop_train_args(budgets, "my-exp", vocab_size, recipe)[0] + >>> # Recipe converts candidate to model-specific configs + >>> model_config = recipe.build_model_config(args.candidate.target_params, vocab_size) + >>> optimizer_config = recipe.build_optimizer_config(args.candidate) """ candidate: CandidateConfig - """The candidate configuration with model/training hyperparameters.""" - - optimizer_config: OptimizerConfig - """Levanter optimizer configuration with learning_rate and beta2 set.""" - - tpu_type: str - """TPU slice type (e.g., 'v5p-8', 'v5p-32').""" + """Model-agnostic compute allocation (batch_size, train_steps, tokens, target_params).""" run_name: str """Name for the training run.""" @@ -206,16 +169,14 @@ class IsoFlopTrainArgs: @dataclass class MinimaRecord: - """Record of optimal configuration found at a specific (label, flops) point.""" + """Model-agnostic record of optimal configuration found at a specific (label, flops) point.""" label: str flops: float optimal_tokens: float loss_at_optimal: float - hidden_dim: int - num_layers: int - batch_size: int optimal_params: float + batch_size: int scaling_alpha: float | None = None scaling_A: float | None = None @@ -433,33 +394,25 @@ def candidate_configs( def _minima_to_candidates( minima_records: list[MinimaRecord], - recipe: ScalingRecipe, ) -> list[CandidateConfig]: - """Convert minima records to CandidateConfig objects. + """Convert minima records to model-agnostic CandidateConfig objects. This is used by both run_isoflop_analysis_step() and run_isoflop_analysis() to convert the fitted minima into usable candidate configs. Args: minima_records: List of optimal configurations from scaling law fits. - recipe: ScalingRecipe with architecture and hyperparameter settings. 
""" configs = [] for rec in minima_records: - if rec.hidden_dim == 0: + if rec.optimal_params == 0: continue configs.append( CandidateConfig( - hidden_size=rec.hidden_dim, - intermediate_dim=rec.hidden_dim * recipe.mlp_ratio, - num_layers=rec.num_layers, - num_heads=max(1, rec.hidden_dim // recipe.hidden_head_ratio), - num_kv_heads=max(1, rec.hidden_dim // recipe.hidden_head_ratio), batch_size=rec.batch_size, train_steps=int(rec.optimal_tokens / (rec.batch_size * SEQ_LEN)), - learning_rate=recipe.compute_learning_rate(rec.batch_size, rec.hidden_dim), - beta2=recipe.compute_beta2(rec.batch_size), tokens=rec.optimal_tokens, + target_params=int(rec.optimal_params), flops_budget=rec.flops, ) ) @@ -477,13 +430,13 @@ def generate_isoflop_train_args( seq_len: int = DEFAULT_SEQ_LEN, steps_per_run: int = DEFAULT_STEPS_PER_RUN, flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, - base_optimizer_config: OptimizerConfig | None = None, ) -> list[IsoFlopTrainArgs]: - """Generate training arguments for each candidate in an isoflop sweep. + """Generate model-agnostic training arguments for each candidate in an isoflop sweep. This is a convenience function that delegates to recipe.generate_isoflop_train_args(). - The recipe encapsulates all model-specific decisions, while this function provides - backward compatibility. + Returns IsoFlopTrainArgs containing model-agnostic CandidateConfig objects. + Use recipe.build_model_config() and recipe.build_optimizer_config() to get + model-specific configs. Args: budgets: Sequence of FLOP budgets to generate configs for. @@ -493,27 +446,26 @@ def generate_isoflop_train_args( seq_len: Sequence length for training. steps_per_run: Reference step count for FLOP budget calculation. flop_tolerance: Tolerance for matching FLOP budget. - base_optimizer_config: Base optimizer config to modify. If None, uses recipe defaults. Returns: List of IsoFlopTrainArgs, one per candidate config across all budgets. Example: - >>> from marin.scaling_laws import generate_isoflop_train_args, DEFAULT_BUDGETS + >>> from marin.scaling_laws import generate_isoflop_train_args, DEFAULT_BUDGETS, ScalingRecipe + >>> recipe = ScalingRecipe(name="my-recipe") >>> train_args = generate_isoflop_train_args( ... budgets=DEFAULT_BUDGETS, ... experiment_name="my-experiment", ... vocab_size=128256, + ... recipe=recipe, ... ) >>> for args in train_args: - ... # Caller constructs the model config from candidate parameters - ... model_config = Qwen3Config( - ... hidden_dim=args.candidate.hidden_size, - ... # ... etc - ... ) + ... # Recipe converts model-agnostic candidate to model-specific configs + ... model_config = recipe.build_model_config(args.candidate.target_params, vocab_size) + ... optimizer_config = recipe.build_optimizer_config(args.candidate, vocab_size) """ return recipe.generate_isoflop_train_args( - budgets, experiment_name, vocab_size, seq_len, steps_per_run, flop_tolerance, base_optimizer_config + budgets, experiment_name, vocab_size, seq_len, steps_per_run, flop_tolerance ) @@ -523,22 +475,32 @@ def generate_isoflop_train_args( def parse_isoflop_run_name(run_name: str) -> str | None: """Parse experiment name from isoflop run name. 
- Expected format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} + Supports two formats: + - New: isoflop-{budget}-N{params}-B{batch}-{experiment_name} + E.g., 'isoflop-1e+18-N1e+08-B128-nemo-wider-depth-adapt' + - Legacy: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} + E.g., 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt' + Optionally with a trailing - which is ignored. - E.g., 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt' - or 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt-a1b2c3' Returns experiment_name or None if parsing fails. """ # Strip optional - suffix run_name = re.sub(r"-[0-9a-fA-F]{6}$", "", run_name) - pattern = r"isoflop-(?:[0-9.e+]+)-d(?:\d+)-L(?:\d+)-B(?:\d+)-(.+)" - match = re.match(pattern, run_name) - if not match: - return None + # New format: isoflop-{budget}-N{params}-B{batch}-{experiment_name} + new_pattern = r"isoflop-(?:[0-9.e+]+)-N(?:[0-9.e+]+)-B(?:\d+)-(.+)" + match = re.match(new_pattern, run_name) + if match: + return match.group(1) + + # Legacy format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} + legacy_pattern = r"isoflop-(?:[0-9.e+]+)-d(?:\d+)-L(?:\d+)-B(?:\d+)-(.+)" + match = re.match(legacy_pattern, run_name) + if match: + return match.group(1) - return match.group(1) + return None def robust_quad_logx(x: jnp.ndarray, y: jnp.ndarray, delta: float = 1.0) -> tuple[float, float, float]: @@ -624,27 +586,25 @@ def fit_scaling_laws( if a == 0: continue - L_opt = -b / (2 * a) - N_star = float(10**L_opt) - loss_opt = float(a * L_opt**2 + b * L_opt + c) + log_D_opt = -b / (2 * a) + D_star = float(10**log_D_opt) + loss_opt = float(a * log_D_opt**2 + b * log_D_opt + c) - idx = (sub.tokens - N_star).abs().argmin() + idx = (sub.tokens - D_star).abs().argmin() nearest_row = sub.iloc[idx] minima_records.append( MinimaRecord( label=lab, flops=float(C), - optimal_tokens=N_star, + optimal_tokens=D_star, loss_at_optimal=loss_opt, - hidden_dim=int(nearest_row["hidden_dim"]), - num_layers=int(nearest_row["num_layers"]), + optimal_params=float(nearest_row.get("params") or C / (6 * D_star)), batch_size=int(nearest_row["batch_size"]), - optimal_params=float(nearest_row.get("params") or C / (6 * N_star)), ) ) - # Fit scaling law N* ~ A * C^alpha per dataset + # Fit scaling law D* ~ A * C^alpha per dataset (optimal tokens ~ compute^alpha) scaling_fits: dict[str, ScalingFit] = {} by_lab: dict[str, list[MinimaRecord]] = {} for rec in minima_records: @@ -657,9 +617,9 @@ def fit_scaling_laws( recs = sorted(recs, key=lambda r: r.flops) Cs = jnp.array([r.flops for r in recs]) - Ns = jnp.array([r.optimal_tokens for r in recs]) + Ds = jnp.array([r.optimal_tokens for r in recs]) - alpha, logA = jnp.polyfit(jnp.log10(Cs), jnp.log10(Ns), 1) + alpha, logA = jnp.polyfit(jnp.log10(Cs), jnp.log10(Ds), 1) A = float(10**logA) alpha = float(alpha) scaling_fits[lab] = ScalingFit(alpha, A) @@ -930,7 +890,7 @@ def run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: for label, (alpha, A) in fit_result.scaling_fits.items(): logger.info(f" {label}: N* = {A:.2e} * C^{alpha:.3f}") - configs = _minima_to_candidates(fit_result.minima_records, config.recipe) + configs = _minima_to_candidates(fit_result.minima_records) result = IsoFlopAnalysisResult( configs=configs, @@ -1016,7 +976,7 @@ def run_isoflop_analysis( logger.info(f"Transformed {len(isoflop_df)} runs for scaling ladder analysis") fit_result = fit_scaling_laws(isoflop_df) - configs = _minima_to_candidates(fit_result.minima_records, recipe) + configs = 
_minima_to_candidates(fit_result.minima_records) return IsoFlopAnalysisResult( configs=configs, diff --git a/lib/marin/src/marin/scaling_laws/recipe.py b/lib/marin/src/marin/scaling_laws/recipe.py index 6a290fcc8f..810e865436 100644 --- a/lib/marin/src/marin/scaling_laws/recipe.py +++ b/lib/marin/src/marin/scaling_laws/recipe.py @@ -14,36 +14,19 @@ """Scaling recipes: model-specific hyperparameter bundles for scaling law experiments. -A ScalingRecipe encapsulates ALL model-specific decisions for scaling experiments: -- Architecture formula (how to compute num_layers from hidden_size) +A ScalingRecipe encapsulates model-specific decisions for scaling experiments: +- Architecture formula (how to compute architecture from target param count) - Architecture ratios (MLP width, head size) - Model config building (returns LlamaConfig or subclass) - Learning rate and optimizer hyperparameters - Search bounds and constraints for isoflop sweeps - Candidate generation for isoflop sweeps - -The scaling_laws module provides model-agnostic utilities (FLOP math, scaling law -fitting), while the recipe provides the model-specific "driver" that uses them. - -Experiments should define their own recipe instances: - - from marin.scaling_laws import ScalingRecipe - - MY_RECIPE = ScalingRecipe(name="my-experiment") - - # Generate candidates for a FLOP budget - for candidate in MY_RECIPE.candidate_configs(budget=1e18, vocab_size=128256): - print(candidate) - - # Build model and optimizer configs - model_config = MY_RECIPE.build_model_config(hidden_size=1024) - optimizer_config = MY_RECIPE.build_optimizer_config(lr=0.001, beta2=0.95) """ import math import os from collections.abc import Iterator, Sequence -from dataclasses import dataclass, replace +from dataclasses import dataclass from typing import TYPE_CHECKING from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig @@ -67,13 +50,10 @@ def _round_to_power_of_two(x: float) -> int: - """Round to the nearest power of 2.""" - if x <= 0: + """Round x UP to the nearest power of 2.""" + if x <= 1: return 1 - log2_x = math.log2(x) - lower = 2 ** int(log2_x) - upper = 2 ** (int(log2_x) + 1) - return lower if (x - lower) < (upper - x) else upper + return 2 ** math.ceil(math.log2(x)) @dataclass(frozen=True) @@ -189,20 +169,95 @@ def get_step_size(self, budget: float) -> int: return self.large_budget_step_size return self.small_budget_step_size + # --- Parameter count estimation --- + + def compute_params_for_hidden_size(self, hidden_size: int, vocab_size: int) -> int: + """Compute approximate parameter count for a given hidden size. + + This uses the standard transformer parameter formula for Llama/Qwen architectures. 
+ """ + num_layers = self.compute_num_layers(hidden_size) + intermediate_dim = hidden_size * self.mlp_ratio + n_heads = max(1, hidden_size // self.hidden_head_ratio) + head_size = hidden_size // n_heads + + # Embeddings + embed_params = vocab_size * hidden_size * 2 # input + output embeddings + + # Per-layer params: attention + mlp + layer norms + q_proj = hidden_size * head_size * n_heads + kv_proj = 2 * hidden_size * head_size * n_heads # K and V + o_proj = head_size * n_heads * hidden_size + attn_params = q_proj + kv_proj + o_proj + + mlp_params = 3 * hidden_size * intermediate_dim # gate, up, down + norm_params = 2 * hidden_size # 2 layer norms per layer + + layer_params = attn_params + mlp_params + norm_params + total_layer_params = num_layers * layer_params + + # Final layer norm + final_norm = hidden_size + + return embed_params + total_layer_params + final_norm + + def hidden_size_for_params(self, target_params: int, vocab_size: int) -> int: + """Find the hidden size that gives approximately target_params. + + Uses binary search over valid hidden sizes. + """ + min_hidden = 2**self.min_hidden_pow + max_hidden = 2**self.max_hidden_pow + + best_hidden = min_hidden + best_diff = abs(self.compute_params_for_hidden_size(min_hidden, vocab_size) - target_params) + + # Search in steps of 64 for efficiency + for hidden_size in range(min_hidden, max_hidden + 1, 64): + params = self.compute_params_for_hidden_size(hidden_size, vocab_size) + diff = abs(params - target_params) + if diff < best_diff: + best_diff = diff + best_hidden = hidden_size + + return best_hidden + # --- Model config building --- - def build_model_config(self, hidden_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: - """Build a model config from hidden_size using this recipe's architecture formula. + def build_model_config(self, target_params: int, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: + """Build a model config for a target parameter count. + + The recipe determines the architecture (hidden_size, num_layers, etc.) + that achieves approximately target_params parameters. Args: - hidden_size: Model hidden dimension. + target_params: Target parameter count. + vocab_size: Vocabulary size. seq_len: Maximum sequence length. Returns: A LlamaConfig (or subclass) with architecture determined by this recipe. """ - # TODO: Currently returns Qwen3Config which inherits from LlamaConfig. - # This could be parameterized to return different model types. + hidden_size = self.hidden_size_for_params(target_params, vocab_size) + num_layers = self.compute_num_layers(hidden_size) + intermediate_dim = hidden_size * self.mlp_ratio + n_heads = max(1, hidden_size // self.hidden_head_ratio) + + return Qwen3Config( + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + num_heads=n_heads, + num_kv_heads=n_heads, + max_seq_len=seq_len, + rope=Llama3RotaryEmbeddingsConfig(), + ) + + def _build_model_config_from_hidden_size(self, hidden_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: + """Internal: build model config from hidden_size directly. + + Used during candidate generation when we're iterating over hidden sizes. 
+ """ num_layers = self.compute_num_layers(hidden_size) intermediate_dim = hidden_size * self.mlp_ratio n_heads = max(1, hidden_size // self.hidden_head_ratio) @@ -217,16 +272,22 @@ def build_model_config(self, hidden_size: int, seq_len: int = DEFAULT_SEQ_LEN) - rope=Llama3RotaryEmbeddingsConfig(), ) - def build_optimizer_config(self, learning_rate: float, beta2: float) -> OptimizerConfig: - """Build optimizer config using this recipe's hyperparameters. + def build_optimizer_config(self, candidate: "CandidateConfig", vocab_size: int) -> OptimizerConfig: + """Build optimizer config for a candidate. + + Computes learning rate and beta2 from the candidate's batch_size and target_params. Args: - learning_rate: Learning rate. - beta2: Adam beta2. + candidate: Model-agnostic candidate config. + vocab_size: Vocabulary size (needed to determine hidden_size). Returns: An OptimizerConfig with settings from this recipe. """ + hidden_size = self.hidden_size_for_params(candidate.target_params, vocab_size) + learning_rate = self.compute_learning_rate(candidate.batch_size, hidden_size) + beta2 = self.compute_beta2(candidate.batch_size) + return CautiousConfig( learning_rate=learning_rate, weight_decay=self.weight_decay, @@ -251,12 +312,14 @@ def candidate_configs( steps_per_run: int = DEFAULT_STEPS_PER_RUN, flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> "Iterator[CandidateConfig]": - """Yield candidate model configurations within the FLOP budget. + """Yield model-agnostic candidate configurations within the FLOP budget. + + This method encapsulates the model-specific search logic internally but + returns model-agnostic CandidateConfig objects containing only: + batch_size, train_steps, tokens, target_params, flops_budget. - This method encapsulates the model-specific search logic: - - Iterates over hidden_sizes using the recipe's search bounds - - Builds model configs using the recipe's architecture formula - - Applies the recipe's LR and batch size constraints + The caller uses recipe.build_model_config() and recipe.build_optimizer_config() + to convert these to model-specific configs. Args: budget: Target FLOP budget. @@ -266,7 +329,7 @@ def candidate_configs( flop_tolerance: Tolerance for matching FLOP budget (relative error). Yields: - CandidateConfig objects for each valid configuration. + Model-agnostic CandidateConfig objects for each valid configuration. 
""" # Import here to avoid circular dependency from marin.scaling_laws.isoflop_analysis import CandidateConfig, solve_for_batch_size, solve_for_train_steps @@ -276,49 +339,36 @@ def candidate_configs( max_hidden = 2**self.max_hidden_pow for hidden_size in range(min_hidden, max_hidden + 1, step_size): - # Build model config using recipe's architecture formula - model_config = self.build_model_config(hidden_size, seq_len) + model_config = self._build_model_config_from_hidden_size(hidden_size, seq_len) - # Use model-agnostic FLOP utilities to solve for batch/steps - # TODO: LlamaConfig.flops_per_token() provides the FLOP calculation batch_exact = solve_for_batch_size(model_config, vocab_size, budget, steps_per_run, seq_len) batch_size = _round_to_power_of_two(batch_exact) - # Apply LR constraint (recipe-specific) + # Apply LR constraint lr = self.compute_learning_rate(batch_size, hidden_size) while lr > self.max_learning_rate: batch_size //= 2 lr = self.compute_learning_rate(batch_size, hidden_size) - # Apply min batch constraint (recipe-specific) if batch_size < self.min_batch_size: continue - # Solve for steps to hit budget with chosen batch train_steps = round(solve_for_train_steps(model_config, vocab_size, budget, batch_size, seq_len)) # Verify we hit the budget within tolerance # Training FLOPs = 3 * flops_per_token * batch * steps * seq_len - # The 3x multiplier accounts for forward (1x) + backward (2x) pass achieved_flops = 3 * model_config.flops_per_token(vocab_size, seq_len) * batch_size * train_steps * seq_len if abs(achieved_flops - budget) / budget > flop_tolerance: continue - # Compute optimizer hyperparameters (recipe-specific) - beta2 = self.compute_beta2(batch_size) tokens = batch_size * train_steps * seq_len + target_params = self.compute_params_for_hidden_size(hidden_size, vocab_size) yield CandidateConfig( - hidden_size=hidden_size, - intermediate_dim=model_config.intermediate_dim, - num_layers=model_config.num_layers, - num_heads=model_config.num_heads, - num_kv_heads=model_config.num_kv_heads, batch_size=batch_size, train_steps=train_steps, - learning_rate=lr, - beta2=beta2, tokens=tokens, + target_params=target_params, flops_budget=budget, ) @@ -330,12 +380,12 @@ def generate_isoflop_train_args( seq_len: int = DEFAULT_SEQ_LEN, steps_per_run: int = DEFAULT_STEPS_PER_RUN, flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, - base_optimizer_config: OptimizerConfig | None = None, ) -> "list[IsoFlopTrainArgs]": - """Generate training arguments for each candidate in an isoflop sweep. + """Generate model-agnostic training arguments for each candidate in an isoflop sweep. - This method generates all the arguments needed to set up training runs for - each candidate configuration in the sweep. + Returns IsoFlopTrainArgs containing the model-agnostic CandidateConfig plus + naming information. The caller uses recipe.build_model_config() and + recipe.build_optimizer_config() to get model-specific configs. Args: budgets: Sequence of FLOP budgets to generate configs for. @@ -344,57 +394,27 @@ def generate_isoflop_train_args( seq_len: Sequence length for training. steps_per_run: Reference step count for FLOP budget calculation. flop_tolerance: Tolerance for matching FLOP budget. - base_optimizer_config: Base optimizer config to modify. If None, uses recipe defaults. Returns: List of IsoFlopTrainArgs, one per candidate config across all budgets. 
""" - # Import here to avoid circular dependency from marin.scaling_laws.isoflop_analysis import IsoFlopTrainArgs - from marin.scaling_laws.tpu_utils import pick_v5p_type - - if base_optimizer_config is None: - base_optimizer_config = CautiousConfig( - learning_rate=1.0, # Placeholder, will be overridden - weight_decay=self.weight_decay, - min_lr_ratio=self.min_lr_ratio, - warmup=self.warmup, - beta1=self.beta1, - beta2=0.98, # Placeholder, will be overridden - epsilon=self.epsilon, - max_grad_norm=self.max_grad_norm, - adamc_weight_decay=True, - lr_schedule=self.lr_schedule, - decay=self.decay, - ) results: list[IsoFlopTrainArgs] = [] for budget in budgets: for candidate in self.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance): - # Pick TPU type based on candidate parameters - tpu_type = pick_v5p_type(candidate, vocab_size, seq_len) - - # Build optimizer config with candidate-specific LR and beta2 - optimizer_cfg = replace( - base_optimizer_config, - learning_rate=candidate.learning_rate, - beta2=candidate.beta2, - ) - - # Generate run name and tags run_name = ( - f"isoflop-{budget:.0e}-d{candidate.hidden_size}-" - f"L{candidate.num_layers}-B{candidate.batch_size}-{experiment_name}" + f"isoflop-{budget:.0e}-N{candidate.target_params:.0e}-" + f"B{candidate.batch_size}-{experiment_name}" ) tags = ( f"FLOPs={budget:.1e}", - f"d={candidate.hidden_size}", - f"L={candidate.num_layers}", + f"N={candidate.target_params:.1e}", f"B={candidate.batch_size}", f"steps={candidate.train_steps}", - f"tpu={tpu_type}", + f"tokens={candidate.tokens:.1e}", ) # Static output path for checkpoint reuse @@ -403,8 +423,6 @@ def generate_isoflop_train_args( results.append( IsoFlopTrainArgs( candidate=candidate, - optimizer_config=optimizer_cfg, - tpu_type=tpu_type, run_name=run_name, tags=tags, output_path=output_path, @@ -468,7 +486,7 @@ def predict_optimal_config( best = max(candidates, key=lambda c: c.tokens) logger.info( - f"Selected config: d={best.hidden_size}, L={best.num_layers}, " + f"Selected config: N={best.target_params:.2e}, " f"B={best.batch_size}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})" ) diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 28f5d2c889..bbadb88ffc 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -15,35 +15,12 @@ """Scaling ladder: compute-optimal training runs based on IsoFLOP analysis. This module provides functions and configs for training models with compute-optimal -configurations derived from IsoFLOP analysis. Experiments create ExecutorSteps -directly using the provided functions. - -Example usage in experiments: - - from marin.execution.executor import ExecutorStep, output_path_of - from marin.scaling_laws import ( - ScalingLadderRungConfig, - run_scaling_ladder_rung, - ) - - # Create optimal training step that depends on analysis output - optimal_step = ExecutorStep( - name="optimal-1e21", - fn=run_scaling_ladder_rung, - config=ScalingLadderRungConfig( - analysis_output_path=output_path_of(analysis_step), - target_budget=1e21, - label="nemo", - tokenized=my_tokenized_dataset, - output_path="checkpoints/optimal-1e21", - ), - ) +configurations derived from IsoFLOP analysis. 
""" import json import logging import os -from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta @@ -53,10 +30,7 @@ from haliax.partitioning import ResourceAxis from levanter.checkpoint import CheckpointerConfig from levanter.data.text import LMMixtureDatasetConfig -from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig from levanter.main.train_lm import TrainLmConfig -from levanter.models.lm_model import LmConfig -from levanter.models.qwen import Qwen3Config from levanter.tracker.wandb import WandbConfig from levanter.trainer import TrainerConfig from levanter.utils.mesh import MeshConfig @@ -64,7 +38,6 @@ from marin.processing.tokenize import get_vocab_size_for_tokenizer from marin.processing.tokenize.data_configs import add_validation_sets_to_mixture, lm_data_config from marin.scaling_laws.isoflop_analysis import ( - CandidateConfig, ScalingFit, predict_optimal_config, ) @@ -74,27 +47,6 @@ logger = logging.getLogger(__name__) -# Type alias for model builder callbacks -# Takes (candidate, seq_len) and returns a model config -ModelBuilder = Callable[[CandidateConfig, int], LmConfig] - - -def default_model_builder(candidate: CandidateConfig, seq_len: int) -> Qwen3Config: - """Default model builder that creates Qwen3Config. - - This is provided as a convenience for the common case. Users can pass - their own model_builder function to use different model types. - """ - return Qwen3Config( - hidden_dim=candidate.hidden_size, - intermediate_dim=candidate.intermediate_dim, - num_layers=candidate.num_layers, - num_heads=candidate.num_heads, - num_kv_heads=candidate.num_kv_heads, - max_seq_len=seq_len, - rope=Llama3RotaryEmbeddingsConfig(), - ) - def _prepare_data_config( tokenized: str | LMMixtureDatasetConfig, @@ -136,6 +88,8 @@ class ScalingLadderRungConfig: This config references an IsoFLOP analysis output and specifies the target compute budget. At runtime, the optimal config is loaded from the analysis output. + + The ScalingRecipe handles all model-specific decisions (architecture, optimizer). """ analysis_output_path: str @@ -154,10 +108,7 @@ class ScalingLadderRungConfig: """Where to write training outputs.""" recipe: ScalingRecipe - """Scaling recipe with hyperparameters.""" - - model_builder: ModelBuilder | None = None - """Function to build model config from CandidateConfig. If None, uses default_model_builder (Qwen3).""" + """Scaling recipe that handles model/optimizer config building.""" tokenizer: str = "stanford-crfm/marin-tokenizer" """Tokenizer to use.""" @@ -170,7 +121,12 @@ class ScalingLadderRungConfig: def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: - """Run one rung of the scaling ladder (one compute-optimal training run).""" + """Run one rung of the scaling ladder (one compute-optimal training run). 
+ + The recipe handles all model-specific decisions: + - Model config is built via `recipe.build_model_config(target_params, vocab_size)` + - Optimizer config is built via `recipe.build_optimizer_config(candidate, vocab_size)` + """ result_path = os.path.join(config.analysis_output_path, "isoflop_analysis_result.json") fs, _, _ = fsspec.get_fs_token_paths(result_path) @@ -201,18 +157,14 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: logger.info( f"Training with optimal config for {config.target_budget:.2e} FLOPs:\n" - f" hidden_size={candidate.hidden_size}, num_layers={candidate.num_layers}\n" + f" target_params={candidate.target_params:.2e}\n" f" batch_size={candidate.batch_size}, train_steps={candidate.train_steps}\n" - f" learning_rate={candidate.learning_rate:.6f}, tokens={candidate.tokens:.2e}" + f" tokens={candidate.tokens:.2e}" ) - # Use provided model builder or default to Qwen3 - model_builder = config.model_builder or default_model_builder - model_cfg = model_builder(candidate, config.seq_len) - - tpu_type = pick_v5p_type(candidate, vocab_size, config.seq_len) - - optimizer_cfg = config.recipe.build_optimizer_config(candidate.learning_rate, candidate.beta2) + model_cfg = config.recipe.build_model_config(candidate.target_params, vocab_size, config.seq_len) + optimizer_cfg = config.recipe.build_optimizer_config(candidate, vocab_size) + tpu_type = pick_v5p_type(candidate, vocab_size, config.seq_len, config.recipe) pretraining_data = _prepare_data_config(config.tokenized, config.validation_sets) @@ -225,6 +177,7 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: "optimal-training", f"FLOPs={config.target_budget:.1e}", f"label={config.label}", + f"N={candidate.target_params:.1e}", ], ), mp=jmp.get_policy("p=f32,c=bfloat16"), diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py index 5b645066de..466cf6af2d 100644 --- a/lib/marin/src/marin/scaling_laws/tpu_utils.py +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from marin.scaling_laws.isoflop_analysis import CandidateConfig + from marin.scaling_laws.recipe import ScalingRecipe # ---------------- TPU v5p Hardware Constants ---------------- # These constants are specific to TPU v5p pods. @@ -84,16 +85,15 @@ def pick_v5p_type( candidate: "CandidateConfig", vocab_size: int, seq_len: int, + recipe: "ScalingRecipe | None" = None, ) -> str: """Select the smallest TPU v5p slice that fits the model in float32. - Uses conservative memory estimation to select a TPU slice size that - can accommodate the model parameters, optimizer states, and activations. - Args: - candidate: CandidateConfig with model architecture parameters. + candidate: CandidateConfig with target_params and batch_size. vocab_size: Vocabulary size. seq_len: Sequence length. + recipe: ScalingRecipe to determine architecture. If None, uses default. Returns: TPU slice name, e.g., "v5p-8" or "v5p-32". @@ -101,21 +101,17 @@ def pick_v5p_type( Raises: ValueError: If the model is too large for available v5p slices. 
""" - # Import here to avoid circular dependency - from marin.scaling_laws.isoflop_analysis import compute_transformer_params - - param_count = compute_transformer_params( - hidden_dim=candidate.hidden_size, - intermediate_dim=candidate.intermediate_dim, - num_layers=candidate.num_layers, - num_heads=candidate.num_heads, - num_kv_heads=candidate.num_kv_heads, - vocab_size=vocab_size, - ) + if recipe is None: + from marin.scaling_laws.recipe import ScalingRecipe + recipe = ScalingRecipe(name="default") + + hidden_size = recipe.hidden_size_for_params(candidate.target_params, vocab_size) + num_layers = recipe.compute_num_layers(hidden_size) + need_bytes = estimate_memory_bytes( - param_count, - candidate.hidden_size, - candidate.num_layers, + candidate.target_params, + hidden_size, + num_layers, candidate.batch_size, seq_len, vocab_size, diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 68402e0296..d4315602da 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -14,13 +14,12 @@ """Unit tests for the scaling_laws module. -These tests focus on integration and end-to-end validation with specific expected outputs, -particularly the snapshot test which ensures reproducibility of config generation. +These tests focus on integration and behavioral validation, particularly +the snapshot test which ensures reproducibility of config generation. """ import jax.numpy as jnp import pandas as pd -import pytest from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig from levanter.models.qwen import Qwen3Config @@ -36,91 +35,17 @@ generate_isoflop_train_args, parse_isoflop_run_name, robust_quad_logx, - round_flops_to_bucket, - round_to_power_of_two, solve_for_batch_size, solve_for_train_steps, transform_metrics_for_isoflop, ) -# --- Utility function tests (parametrized) --- - - -@pytest.mark.parametrize( - "value,expected", - [ - # Exact powers unchanged - (1, 1), - (2, 2), - (4, 4), - (16, 16), - # Non-powers round up - (3, 4), - (5, 8), - (9, 16), - # Small/zero values become 1 - (0.5, 1), - (0, 1), - # Large values - (100, 128), - (1000, 1024), - ], -) -def test_round_to_power_of_two(value, expected): - """Test round_to_power_of_two produces correct results.""" - assert round_to_power_of_two(value) == expected - - -@pytest.mark.parametrize( - "value,expected", - [ - # Exact values unchanged - (1e18, 1e18), - (1e19, 1e19), - (3e19, 3e19), - # Rounds to 1 significant figure - (1.05e19, 1e19), - (1.4e19, 1e19), - (1.5e19, 2e19), - (2.8e19, 3e19), - (9.5e19, 1e20), - # Edge cases - (0, 0), - ], -) -def test_round_flops_to_bucket(value, expected): - """Test round_flops_to_bucket rounds to 1 significant figure.""" - assert round_flops_to_bucket(value) == expected - # --- FLOP computation tests --- -def test_compute_training_flops_linear_in_batch_and_steps(): - """Test that FLOPs scale linearly with batch size and steps.""" - # Build a model config for testing - model_config = Qwen3Config( - max_seq_len=4096, - hidden_dim=512, - intermediate_dim=2048, - num_heads=8, - num_kv_heads=8, - num_layers=12, - rope=Llama3RotaryEmbeddingsConfig(), - ) - vocab_size = 128256 - seq_len = 4096 - - base_flops = compute_training_flops(model_config, vocab_size, 32, 1000, seq_len) - double_batch_flops = compute_training_flops(model_config, vocab_size, 64, 1000, seq_len) - double_steps_flops = compute_training_flops(model_config, vocab_size, 32, 2000, seq_len) - - assert abs(double_batch_flops - 2 * base_flops) / base_flops < 0.01 - assert abs(double_steps_flops - 2 * 
base_flops) / base_flops < 0.01 - - -def test_solve_for_batch_size_inverts_flop_calculation(): - """Test that solve_for_batch_size correctly inverts compute_training_flops.""" +def test_flop_solvers_are_consistent(): + """Test that FLOP solvers correctly invert the FLOP calculation.""" model_config = Qwen3Config( max_seq_len=4096, hidden_dim=768, @@ -132,79 +57,34 @@ def test_solve_for_batch_size_inverts_flop_calculation(): ) vocab_size = 128256 seq_len = 4096 - train_steps = 10000 - original_batch_size = 64 - - # Compute FLOPs for known batch size - target_flops = compute_training_flops(model_config, vocab_size, original_batch_size, train_steps, seq_len) - # Solve for batch size given those FLOPs + # Verify solve_for_batch_size inverts compute_training_flops + original_batch = 64 + train_steps = 10000 + target_flops = compute_training_flops(model_config, vocab_size, original_batch, train_steps, seq_len) recovered_batch = solve_for_batch_size(model_config, vocab_size, target_flops, train_steps, seq_len) + assert abs(recovered_batch - original_batch) < 0.01 - # Should recover original batch size (exact float) - assert abs(recovered_batch - original_batch_size) < 0.01 - - -def test_solve_for_train_steps_inverts_flop_calculation(): - """Test that solve_for_train_steps correctly inverts compute_training_flops.""" - model_config = Qwen3Config( - max_seq_len=4096, - hidden_dim=1024, - intermediate_dim=4096, - num_heads=8, - num_kv_heads=8, - num_layers=16, - rope=Llama3RotaryEmbeddingsConfig(), - ) - vocab_size = 128256 - seq_len = 4096 - batch_size = 32 + # Verify solve_for_train_steps inverts compute_training_flops original_steps = 50000 - - # Compute FLOPs for known steps + batch_size = 32 target_flops = compute_training_flops(model_config, vocab_size, batch_size, original_steps, seq_len) - - # Solve for steps given those FLOPs recovered_steps = solve_for_train_steps(model_config, vocab_size, target_flops, batch_size, seq_len) - - # Should recover original steps (exact float) assert abs(recovered_steps - original_steps) < 0.01 -def test_solvers_consistent_with_each_other(): - """Test that solving for batch and then steps gives consistent results.""" - model_config = Qwen3Config( - max_seq_len=4096, - hidden_dim=512, - intermediate_dim=2048, - num_heads=8, - num_kv_heads=8, - num_layers=8, - rope=Llama3RotaryEmbeddingsConfig(), - ) - vocab_size = 128256 - seq_len = 4096 - target_flops = 1e19 - - # Pick arbitrary steps, solve for batch - steps = 20000 - batch = solve_for_batch_size(model_config, vocab_size, target_flops, steps, seq_len) - - # Now with that batch, solve for steps - should get back original - recovered_steps = solve_for_train_steps(model_config, vocab_size, target_flops, round(batch), seq_len) - - # Allow small error from rounding batch to int - relative_error = abs(recovered_steps - steps) / steps - assert relative_error < 0.01 - - # --- Run name parsing tests --- def test_parse_isoflop_run_name(): """Test parsing isoflop run names extracts experiment names.""" - result = parse_isoflop_run_name("isoflop-1e+18-d512-L8-B128-dclm-a1b2c3") - assert result == "dclm" + # New format: isoflop-{budget}-N{params}-B{batch}-{experiment_name} + assert parse_isoflop_run_name("isoflop-1e+18-N1e+08-B128-nemo-wider-depth-adapt") == "nemo-wider-depth-adapt" + assert parse_isoflop_run_name("isoflop-1e+18-N1e+08-B128-dclm-a1b2c3") == "dclm" # hash stripped + + # Legacy format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} + assert 
parse_isoflop_run_name("isoflop-1e+18-d512-L8-B128-dclm-a1b2c3") == "dclm" + assert parse_isoflop_run_name("isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt") == "nemo-wider-depth-adapt" # Invalid formats return None assert parse_isoflop_run_name("not-a-valid-name") is None @@ -220,8 +100,8 @@ def test_candidate_configs_within_tolerance(): budget = 1e19 flop_tolerance = 0.01 seq_len = DEFAULT_SEQ_LEN + for candidate in candidate_configs(budget, MARIN_TOKENIZER_VOCAB_SIZE, recipe, flop_tolerance=flop_tolerance): - # Build model config from candidate using recipe model_config = recipe.build_model_config(candidate.target_params, MARIN_TOKENIZER_VOCAB_SIZE, seq_len) achieved = compute_training_flops( model_config, @@ -254,212 +134,107 @@ def test_robust_quad_logx_fits_quadratic(): # --- Snapshot test for config generation --- # Snapshot of expected output for generate_isoflop_train_args with budget=3e18 training FLOPs. -# Note: compute_training_flops includes the 3x multiplier for training (forward + backward pass), -# matching how FLOPs are tracked in WandB via Levanter's log_performance_stats. -# -# CandidateConfig is now model-agnostic, containing only: -# - batch_size, train_steps, tokens, target_params, flops_budget EXPECTED_ISOFLOP_CONFIGS_3E18 = [ - { - "batch_size": 32, - "train_steps": 32844, - "flops_budget": 3e18, - }, - { - "batch_size": 16, - "train_steps": 46274, - "flops_budget": 3e18, - }, - { - "batch_size": 16, - "train_steps": 33965, - "flops_budget": 3e18, - }, - { - "batch_size": 8, - "train_steps": 48105, - "flops_budget": 3e18, - }, - { - "batch_size": 8, - "train_steps": 37335, - "flops_budget": 3e18, - }, + {"batch_size": 32, "train_steps": 32844, "flops_budget": 3e18}, + {"batch_size": 16, "train_steps": 46274, "flops_budget": 3e18}, + {"batch_size": 16, "train_steps": 33965, "flops_budget": 3e18}, + {"batch_size": 8, "train_steps": 48105, "flops_budget": 3e18}, + {"batch_size": 8, "train_steps": 37335, "flops_budget": 3e18}, ] def test_generate_isoflop_train_args_snapshot(): """Snapshot test: verify generate_isoflop_train_args produces expected configs. - This test ensures the scaling_laws module produces identical configurations - for reproducible isoflop sweeps. Uses 3e18 training FLOPs budget (which accounts - for the 3x multiplier for forward + backward pass). - - CandidateConfig is now model-agnostic, so we only check the core compute - allocation parameters (batch_size, train_steps, flops_budget). + This ensures reproducibility of the config generation algorithm. 
""" recipe = ScalingRecipe(name="test-snapshot") - budgets = (3e18,) result = generate_isoflop_train_args( - budgets=budgets, + budgets=(3e18,), experiment_name="test-snapshot", vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, recipe=recipe, ) - assert len(result) == len( - EXPECTED_ISOFLOP_CONFIGS_3E18 - ), f"Expected {len(EXPECTED_ISOFLOP_CONFIGS_3E18)} configs, got {len(result)}" + assert len(result) == len(EXPECTED_ISOFLOP_CONFIGS_3E18) for i, (args, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_3E18, strict=True)): assert isinstance(args, IsoFlopTrainArgs) c = args.candidate - actual = { - "batch_size": c.batch_size, - "train_steps": c.train_steps, - "flops_budget": c.flops_budget, - } - - for key in expected: - assert ( - actual[key] == expected[key] - ), f"Config {i}: {key} mismatch: expected {expected[key]}, got {actual[key]}" + assert c.batch_size == expected["batch_size"], f"Config {i}: batch_size mismatch" + assert c.train_steps == expected["train_steps"], f"Config {i}: train_steps mismatch" + assert c.flops_budget == expected["flops_budget"], f"Config {i}: flops_budget mismatch" -# --- Metrics transformation tests --- +# --- End-to-end integration test --- -# Sample tracker_metrics.jsonl data extracted from real runs +# Sample tracker_metrics.jsonl data simulating real runs SAMPLE_METRICS_DATA = [ # 1e18 budget - 3 runs with U-shaped loss curve { - "run_path": "gs://marin/checkpoints/isoflop-1e+18-d1024-L11-B8-nemo-wider-depth-adapt", - "config": { - "model": {"hidden_dim": 1024, "num_layers": 11}, - "trainer": {"train_batch_size": 8}, - }, + "run_path": "gs://marin/checkpoints/isoflop-1e+18-d1024-L11-B8-nemo", + "config": {"model": {"hidden_dim": 1024, "num_layers": 11}, "trainer": {"train_batch_size": 8}}, "summary": { - "throughput/total_tokens": 1000000000, - "throughput/total_gflops": 1000000000.0, + "throughput/total_tokens": 1e9, + "throughput/total_gflops": 1e9, "eval/paloma/c4_en/bpb": 1.25, - "parameter_count": 400000000, + "parameter_count": 4e8, }, }, { - "run_path": "gs://marin/checkpoints/isoflop-1e+18-d768-L8-B16-nemo-wider-depth-adapt", - "config": { - "model": {"hidden_dim": 768, "num_layers": 8}, - "trainer": {"train_batch_size": 16}, - }, + "run_path": "gs://marin/checkpoints/isoflop-1e+18-d768-L8-B16-nemo", + "config": {"model": {"hidden_dim": 768, "num_layers": 8}, "trainer": {"train_batch_size": 16}}, "summary": { - "throughput/total_tokens": 2500000000, - "throughput/total_gflops": 1000000000.0, + "throughput/total_tokens": 2.5e9, + "throughput/total_gflops": 1e9, "eval/paloma/c4_en/bpb": 1.12, - "parameter_count": 272513792, + "parameter_count": 2.7e8, }, }, { - "run_path": "gs://marin/checkpoints/isoflop-1e+18-d512-L6-B32-nemo-wider-depth-adapt", - "config": { - "model": {"hidden_dim": 512, "num_layers": 6}, - "trainer": {"train_batch_size": 32}, - }, + "run_path": "gs://marin/checkpoints/isoflop-1e+18-d512-L6-B32-nemo", + "config": {"model": {"hidden_dim": 512, "num_layers": 6}, "trainer": {"train_batch_size": 32}}, "summary": { - "throughput/total_tokens": 5000000000, - "throughput/total_gflops": 1000000000.0, + "throughput/total_tokens": 5e9, + "throughput/total_gflops": 1e9, "eval/paloma/c4_en/bpb": 1.18, - "parameter_count": 156508160, + "parameter_count": 1.5e8, }, }, # 1e19 budget - 3 runs { - "run_path": "gs://marin/checkpoints/isoflop-1e+19-d2048-L21-B16-nemo-wider-depth-adapt", - "config": { - "model": {"hidden_dim": 2048, "num_layers": 21}, - "trainer": {"train_batch_size": 16}, - }, + "run_path": 
"gs://marin/checkpoints/isoflop-1e+19-d2048-L21-B16-nemo", + "config": {"model": {"hidden_dim": 2048, "num_layers": 21}, "trainer": {"train_batch_size": 16}}, "summary": { - "throughput/total_tokens": 3000000000, - "throughput/total_gflops": 10000000000.0, + "throughput/total_tokens": 3e9, + "throughput/total_gflops": 1e10, "eval/paloma/c4_en/bpb": 1.05, - "parameter_count": 1800000000, + "parameter_count": 1.8e9, }, }, { - "run_path": "gs://marin/checkpoints/isoflop-1e+19-d1536-L16-B32-nemo-wider-depth-adapt", - "config": { - "model": {"hidden_dim": 1536, "num_layers": 16}, - "trainer": {"train_batch_size": 32}, - }, + "run_path": "gs://marin/checkpoints/isoflop-1e+19-d1536-L16-B32-nemo", + "config": {"model": {"hidden_dim": 1536, "num_layers": 16}, "trainer": {"train_batch_size": 32}}, "summary": { - "throughput/total_tokens": 8000000000, - "throughput/total_gflops": 10000000000.0, + "throughput/total_tokens": 8e9, + "throughput/total_gflops": 1e10, "eval/paloma/c4_en/bpb": 0.98, - "parameter_count": 998036992, + "parameter_count": 1e9, }, }, { - "run_path": "gs://marin/checkpoints/isoflop-1e+19-d1024-L11-B64-nemo-wider-depth-adapt", - "config": { - "model": {"hidden_dim": 1024, "num_layers": 11}, - "trainer": {"train_batch_size": 64}, - }, + "run_path": "gs://marin/checkpoints/isoflop-1e+19-d1024-L11-B64-nemo", + "config": {"model": {"hidden_dim": 1024, "num_layers": 11}, "trainer": {"train_batch_size": 64}}, "summary": { - "throughput/total_tokens": 20000000000, - "throughput/total_gflops": 10000000000.0, + "throughput/total_tokens": 2e10, + "throughput/total_gflops": 1e10, "eval/paloma/c4_en/bpb": 1.02, - "parameter_count": 400000000, + "parameter_count": 4e8, }, }, ] -def test_transform_metrics_for_isoflop(): - """Test transformation of raw metrics data to isoflop analysis format.""" - raw_df = pd.DataFrame(SAMPLE_METRICS_DATA) - metric_key = "eval/paloma/c4_en/bpb" - - result = transform_metrics_for_isoflop(raw_df, metric_key) - - assert len(result) == 6 # 3 runs at 1e18 + 3 runs at 1e19 - - # Verify specific values from first row (d1024/L11) - row0 = result.iloc[0] - assert row0["tokens"] == 1000000000 - assert row0["loss"] == 1.25 - assert row0["hidden_dim"] == 1024 - assert row0["num_layers"] == 11 - assert row0["batch_size"] == 8 - assert row0["flops"] == 1e18 - assert row0["params"] == 400000000 - - -def test_transform_metrics_filters_low_flops(): - """Test that runs with < 1e18 FLOPs are filtered out.""" - raw_df = pd.DataFrame( - [ - { - "run_path": "gs://marin/checkpoints/small-run", - "config": { - "model": {"hidden_dim": 256, "num_layers": 4}, - "trainer": {"train_batch_size": 8}, - }, - "summary": { - "throughput/total_tokens": 1e7, - "throughput/total_gflops": 1e6, # Only 1e15 FLOPs - "eval/paloma/c4_en/bpb": 3.0, - "parameter_count": 1e7, - }, - } - ] - ) - - result = transform_metrics_for_isoflop(raw_df, "eval/paloma/c4_en/bpb") - assert len(result) == 0 - - -# --- End-to-end integration test --- - - def test_end_to_end_analysis_pipeline(): """Integration test: transform metrics and fit scaling laws end-to-end. 
@@ -477,17 +252,15 @@ def test_end_to_end_analysis_pipeline(): # Should find two minima (one per budget: 1e18 and 1e19) assert len(fit_result.minima_records) == 2 - flops_budgets = {rec.flops for rec in fit_result.minima_records} - assert flops_budgets == {1e18, 1e19} + assert {rec.flops for rec in fit_result.minima_records} == {1e18, 1e19} # Verify fitted minima are near expected optimal points - # Curve fitting interpolates to find analytical minimum of fitted quadratic minima_by_flops = {rec.flops: rec for rec in fit_result.minima_records} - # At 1e18: raw data optimal at 2.5B (loss=1.12), fitted minimum ~2.6B + # At 1e18: raw data optimal at 2.5B tokens (loss=1.12) assert abs(minima_by_flops[1e18].optimal_tokens - 2.6e9) < 0.2e9 assert abs(minima_by_flops[1e18].loss_at_optimal - 1.12) < 0.01 - # At 1e19: raw data optimal at 8B (loss=0.98), fitted minimum ~8.8B + # At 1e19: raw data optimal at 8B tokens (loss=0.98) assert abs(minima_by_flops[1e19].optimal_tokens - 8.8e9) < 0.2e9 assert abs(minima_by_flops[1e19].loss_at_optimal - 0.98) < 0.01 From ff33db61e1dc7fe69e84a06f4f06347bb1ed6b96 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 7 Jan 2026 22:12:26 -0800 Subject: [PATCH 53/79] Lint --- experiments/exp1603_subgroup_evals.py | 15 +++------------ lib/marin/src/marin/scaling_laws/recipe.py | 3 +-- lib/marin/src/marin/scaling_laws/tpu_utils.py | 1 + tests/test_scaling_laws.py | 1 - 4 files changed, 5 insertions(+), 15 deletions(-) diff --git a/experiments/exp1603_subgroup_evals.py b/experiments/exp1603_subgroup_evals.py index 610c9d9e36..30e63b0fd0 100644 --- a/experiments/exp1603_subgroup_evals.py +++ b/experiments/exp1603_subgroup_evals.py @@ -46,10 +46,7 @@ def create_eval_steps() -> list: dist_eval = distributional_eval_sets(llama3_tokenizer) for model, candidate in list(zip(*MARIN_SCALING_SUITES["nemotron"], strict=False)): total_tokens = candidate.batch_size * candidate.train_steps * 4096 - name = ( - f"marin-nemo-{candidate.flops_budget:.0e}C-{total_tokens}T-" - f"N{candidate.target_params:.0e}" - ) + name = f"marin-nemo-{candidate.flops_budget:.0e}C-{total_tokens}T-" f"N{candidate.target_params:.0e}" step = evaluate_levanter_lm_evaluation_harness( model_name=name, @@ -73,10 +70,7 @@ def create_eval_steps() -> list: for model, candidate in list(zip(*MARIN_SCALING_SUITES["common_pile"], strict=False)): total_tokens = candidate.batch_size * candidate.train_steps * 4096 - name = ( - f"marin-comma-{candidate.flops_budget:.0e}C-{total_tokens}T-" - f"N{candidate.target_params:.0e}" - ) + name = f"marin-comma-{candidate.flops_budget:.0e}C-{total_tokens}T-" f"N{candidate.target_params:.0e}" step = evaluate_levanter_lm_evaluation_harness( model_name=name, @@ -100,10 +94,7 @@ def create_eval_steps() -> list: for model, candidate in list(zip(*MARIN_SCALING_SUITES["dclm-default"], strict=False)): total_tokens = candidate.batch_size * candidate.train_steps * 4096 - name = ( - f"marin-dclm-{candidate.flops_budget:.0e}C-{total_tokens}T-" - f"N{candidate.target_params:.0e}" - ) + name = f"marin-dclm-{candidate.flops_budget:.0e}C-{total_tokens}T-" f"N{candidate.target_params:.0e}" step = evaluate_levanter_lm_evaluation_harness( model_name=name, diff --git a/lib/marin/src/marin/scaling_laws/recipe.py b/lib/marin/src/marin/scaling_laws/recipe.py index 810e865436..e973583859 100644 --- a/lib/marin/src/marin/scaling_laws/recipe.py +++ b/lib/marin/src/marin/scaling_laws/recipe.py @@ -405,8 +405,7 @@ def generate_isoflop_train_args( for budget in budgets: for candidate in 
self.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance): run_name = ( - f"isoflop-{budget:.0e}-N{candidate.target_params:.0e}-" - f"B{candidate.batch_size}-{experiment_name}" + f"isoflop-{budget:.0e}-N{candidate.target_params:.0e}-" f"B{candidate.batch_size}-{experiment_name}" ) tags = ( diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py index 466cf6af2d..165df2ab0e 100644 --- a/lib/marin/src/marin/scaling_laws/tpu_utils.py +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -103,6 +103,7 @@ def pick_v5p_type( """ if recipe is None: from marin.scaling_laws.recipe import ScalingRecipe + recipe = ScalingRecipe(name="default") hidden_size = recipe.hidden_size_for_params(candidate.target_params, vocab_size) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index d4315602da..90c865b7af 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -40,7 +40,6 @@ transform_metrics_for_isoflop, ) - # --- FLOP computation tests --- From 19e68478a0c6093cfb8adf7af02e4408651c829b Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 8 Jan 2026 10:46:03 -0800 Subject: [PATCH 54/79] Keep Lib Opinion Clean Even More --- experiments/isoflop_sweep.py | 319 ++++++++++++- .../marin/scaling_laws/isoflop_analysis.py | 5 +- lib/marin/src/marin/scaling_laws/recipe.py | 436 +----------------- lib/marin/src/marin/scaling_laws/tpu_utils.py | 9 +- tests/test_scaling_laws.py | 8 +- 5 files changed, 343 insertions(+), 434 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 8e621d31d1..826b13f48d 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -18,9 +18,17 @@ sizes while keeping the total training FLOPs roughly constant. """ -from dataclasses import replace +import math +import os +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, replace from levanter.data.text import LMMixtureDatasetConfig +from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig +from levanter.models.llama import LlamaConfig +from levanter.models.qwen import Qwen3Config +from levanter.optim.cautious import CautiousConfig +from levanter.optim.config import OptimizerConfig from experiments.evals.evals import default_eval from experiments.evals.task_configs import EvalTaskConfig @@ -35,13 +43,312 @@ from marin.processing.tokenize import get_vocab_size_for_tokenizer, lm_mixture_data_config from marin.scaling_laws import ( DEFAULT_BUDGETS, + DEFAULT_FLOP_TOLERANCE, + DEFAULT_SEQ_LEN, + DEFAULT_STEPS_PER_RUN, CandidateConfig, + IsoFlopTrainArgs, ScalingRecipe, generate_isoflop_train_args, pick_v5p_type, + solve_for_batch_size, + solve_for_train_steps, ) -MARIN_2025_RECIPE = ScalingRecipe(name="marin-2025") + +def _round_to_power_of_two(x: float) -> int: + """Round x UP to the nearest power of 2.""" + if x <= 1: + return 1 + return 2 ** math.ceil(math.log2(x)) + + +@dataclass(frozen=True) +class Marin2025Recipe: + """Marin 2025 scaling recipe with all hyperparameters and formulas. + + This recipe implements all the Marin-specific decisions for scaling experiments. 
+ """ + + name: str = "marin-2025" + + # --- Learning rate scaling --- + # lr = lr_constant * sqrt(batch_size) / hidden_dim + lr_constant: float = 0.33 + + # --- Beta2 scaling for Adam --- + # beta2 = beta2_base ** (batch_size / beta2_batch_divisor) + beta2_base: float = 0.98 + beta2_batch_divisor: float = 128 + + # --- Optimizer hyperparameters --- + weight_decay: float = 0.1 + min_lr_ratio: float = 0.0 + warmup: float = 0.1 + beta1: float = 0.95 + epsilon: float = 1e-15 + max_grad_norm: float = 1.0 + lr_schedule: str = "linear" + decay: float = 0.2 + + # --- Architecture ratios --- + mlp_ratio: int = 4 + hidden_head_ratio: int = 128 + + # --- Architecture formula for depth-to-width scaling --- + base_hidden_layer_ratio: int = 64 + layer_scaling_factor: float = 4.0 + layer_formula_offset: int = 9 + + # --- Constraints --- + max_learning_rate: float = 0.01 + min_batch_size: int = 8 + + # --- Search bounds for isoflop sweeps --- + min_hidden_pow: int = 9 + max_hidden_pow: int = 12 + small_budget_step_size: int = 128 + large_budget_step_size: int = 256 + budget_step_threshold: float = 9e18 + + def _compute_learning_rate(self, batch_size: int, hidden_dim: int) -> float: + """Compute learning rate from batch size and hidden dim.""" + return (self.lr_constant * math.sqrt(batch_size)) / hidden_dim + + def _compute_beta2(self, batch_size: int) -> float: + """Compute beta2 from batch size.""" + return self.beta2_base ** (batch_size / self.beta2_batch_divisor) + + def compute_num_layers(self, hidden_size: int) -> int: + """Compute number of layers from hidden size using the depth-width formula.""" + hs_pow = math.log2(hidden_size) + return round( + hidden_size + / (self.base_hidden_layer_ratio + (hs_pow * self.layer_scaling_factor) - self.layer_formula_offset) + ) + + def _get_step_size(self, budget: float) -> int: + """Get hidden_size search step size based on budget.""" + if budget > self.budget_step_threshold: + return self.large_budget_step_size + return self.small_budget_step_size + + def _compute_params_for_hidden_size(self, hidden_size: int, vocab_size: int) -> int: + """Compute approximate parameter count for a given hidden size.""" + num_layers = self.compute_num_layers(hidden_size) + intermediate_dim = hidden_size * self.mlp_ratio + n_heads = max(1, hidden_size // self.hidden_head_ratio) + head_size = hidden_size // n_heads + + embed_params = vocab_size * hidden_size * 2 + q_proj = hidden_size * head_size * n_heads + kv_proj = 2 * hidden_size * head_size * n_heads + o_proj = head_size * n_heads * hidden_size + attn_params = q_proj + kv_proj + o_proj + mlp_params = 3 * hidden_size * intermediate_dim + norm_params = 2 * hidden_size + layer_params = attn_params + mlp_params + norm_params + total_layer_params = num_layers * layer_params + final_norm = hidden_size + + return embed_params + total_layer_params + final_norm + + def hidden_size_for_params(self, target_params: int, vocab_size: int) -> int: + """Find the hidden size that gives approximately target_params.""" + min_hidden = 2**self.min_hidden_pow + max_hidden = 2**self.max_hidden_pow + + best_hidden = min_hidden + best_diff = abs(self._compute_params_for_hidden_size(min_hidden, vocab_size) - target_params) + + for hidden_size in range(min_hidden, max_hidden + 1, 64): + params = self._compute_params_for_hidden_size(hidden_size, vocab_size) + diff = abs(params - target_params) + if diff < best_diff: + best_diff = diff + best_hidden = hidden_size + + return best_hidden + + def build_model_config(self, target_params: int, vocab_size: 
int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: + """Build a Qwen3 model config for a target parameter count.""" + hidden_size = self.hidden_size_for_params(target_params, vocab_size) + num_layers = self.compute_num_layers(hidden_size) + intermediate_dim = hidden_size * self.mlp_ratio + n_heads = max(1, hidden_size // self.hidden_head_ratio) + + return Qwen3Config( + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + num_heads=n_heads, + num_kv_heads=n_heads, + max_seq_len=seq_len, + rope=Llama3RotaryEmbeddingsConfig(), + ) + + def _build_model_config_from_hidden_size(self, hidden_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: + """Build model config from hidden_size directly.""" + num_layers = self.compute_num_layers(hidden_size) + intermediate_dim = hidden_size * self.mlp_ratio + n_heads = max(1, hidden_size // self.hidden_head_ratio) + + return Qwen3Config( + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + num_heads=n_heads, + num_kv_heads=n_heads, + max_seq_len=seq_len, + rope=Llama3RotaryEmbeddingsConfig(), + ) + + def build_optimizer_config(self, candidate: CandidateConfig, vocab_size: int) -> OptimizerConfig: + """Build optimizer config for a candidate.""" + hidden_size = self.hidden_size_for_params(candidate.target_params, vocab_size) + learning_rate = self._compute_learning_rate(candidate.batch_size, hidden_size) + beta2 = self._compute_beta2(candidate.batch_size) + + return CautiousConfig( + learning_rate=learning_rate, + weight_decay=self.weight_decay, + min_lr_ratio=self.min_lr_ratio, + warmup=self.warmup, + beta1=self.beta1, + beta2=beta2, + epsilon=self.epsilon, + max_grad_norm=self.max_grad_norm, + adamc_weight_decay=True, + lr_schedule=self.lr_schedule, + decay=self.decay, + ) + + def candidate_configs( + self, + budget: float, + vocab_size: int, + seq_len: int = DEFAULT_SEQ_LEN, + steps_per_run: int = DEFAULT_STEPS_PER_RUN, + flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, + ) -> Iterator[CandidateConfig]: + """Yield candidate configurations within the FLOP budget.""" + step_size = self._get_step_size(budget) + min_hidden = 2**self.min_hidden_pow + max_hidden = 2**self.max_hidden_pow + + for hidden_size in range(min_hidden, max_hidden + 1, step_size): + model_config = self._build_model_config_from_hidden_size(hidden_size, seq_len) + + batch_exact = solve_for_batch_size(model_config, vocab_size, budget, steps_per_run, seq_len) + batch_size = _round_to_power_of_two(batch_exact) + + lr = self._compute_learning_rate(batch_size, hidden_size) + while lr > self.max_learning_rate: + batch_size //= 2 + lr = self._compute_learning_rate(batch_size, hidden_size) + + if batch_size < self.min_batch_size: + continue + + train_steps = round(solve_for_train_steps(model_config, vocab_size, budget, batch_size, seq_len)) + + achieved_flops = 3 * model_config.flops_per_token(vocab_size, seq_len) * batch_size * train_steps * seq_len + if abs(achieved_flops - budget) / budget > flop_tolerance: + continue + + tokens = batch_size * train_steps * seq_len + target_params = self._compute_params_for_hidden_size(hidden_size, vocab_size) + + yield CandidateConfig( + batch_size=batch_size, + train_steps=train_steps, + tokens=tokens, + target_params=target_params, + flops_budget=budget, + ) + + def generate_isoflop_train_args( + self, + budgets: Sequence[float], + experiment_name: str, + vocab_size: int, + seq_len: int = DEFAULT_SEQ_LEN, + steps_per_run: int = DEFAULT_STEPS_PER_RUN, + flop_tolerance: float = 
DEFAULT_FLOP_TOLERANCE, + ) -> list[IsoFlopTrainArgs]: + """Generate training arguments for each candidate in an isoflop sweep.""" + results: list[IsoFlopTrainArgs] = [] + + for budget in budgets: + for candidate in self.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance): + run_name = ( + f"isoflop-{budget:.0e}-N{candidate.target_params:.0e}-" f"B{candidate.batch_size}-{experiment_name}" + ) + + tags = ( + f"FLOPs={budget:.1e}", + f"N={candidate.target_params:.1e}", + f"B={candidate.batch_size}", + f"steps={candidate.train_steps}", + f"tokens={candidate.tokens:.1e}", + ) + + output_path = os.path.join("checkpoints", "isoflop", run_name) + + results.append( + IsoFlopTrainArgs( + candidate=candidate, + run_name=run_name, + tags=tags, + output_path=output_path, + ) + ) + + return results + + def predict_optimal_config( + self, + scaling_fits: dict[str, tuple[float, float]], + target_flops: float, + label: str, + vocab_size: int, + seq_len: int = DEFAULT_SEQ_LEN, + steps_per_run: int = DEFAULT_STEPS_PER_RUN, + flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, + ) -> CandidateConfig | None: + """Predict optimal training config for a target compute budget using fitted scaling laws.""" + import logging + + logger = logging.getLogger(__name__) + + if label not in scaling_fits: + logger.warning(f"Label '{label}' not found in scaling fits") + return None + + alpha, A = scaling_fits[label] + optimal_tokens = A * (target_flops**alpha) + + logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") + + candidates = list(self.candidate_configs(target_flops, vocab_size, seq_len, steps_per_run, flop_tolerance)) + + if not candidates: + logger.warning(f"No valid candidates found for budget {target_flops:.2e}") + return None + + best = min(candidates, key=lambda c: c.tokens - optimal_tokens if c.tokens >= optimal_tokens else float("inf")) + if best.tokens < optimal_tokens: + best = max(candidates, key=lambda c: c.tokens) + + logger.info( + f"Selected config: N={best.target_params:.2e}, " + f"B={best.batch_size}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})" + ) + + return best + + +MARIN_2025_RECIPE = Marin2025Recipe() """Default Marin scaling recipe.""" @@ -82,16 +389,12 @@ def create_isoflop_sweep_steps( recipe=recipe, ) - # Base config for training runs + # Base config for training runs (values overridden per-candidate via optimizer_config) base_train_config = SimpleTrainConfig( resources=ResourceConfig.with_tpu("v5p-8"), train_batch_size=1, num_train_steps=50_000, - learning_rate=1.0, # Placeholder, will be overridden - weight_decay=recipe.weight_decay, - min_lr_ratio=recipe.min_lr_ratio, - lr_schedule=recipe.lr_schedule, - decay=recipe.decay, + learning_rate=1.0, # Overridden via optimizer_config ) train_steps: list[ExecutorStep] = [] diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 260c16a516..d63741a5ee 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -451,8 +451,9 @@ def generate_isoflop_train_args( List of IsoFlopTrainArgs, one per candidate config across all budgets. 
Example: - >>> from marin.scaling_laws import generate_isoflop_train_args, DEFAULT_BUDGETS, ScalingRecipe - >>> recipe = ScalingRecipe(name="my-recipe") + >>> from marin.scaling_laws import generate_isoflop_train_args, DEFAULT_BUDGETS + >>> # Use a concrete recipe implementation (e.g., from experiments/isoflop_sweep.py) + >>> # recipe = Marin2025Recipe() >>> train_args = generate_isoflop_train_args( ... budgets=DEFAULT_BUDGETS, ... experiment_name="my-experiment", diff --git a/lib/marin/src/marin/scaling_laws/recipe.py b/lib/marin/src/marin/scaling_laws/recipe.py index e973583859..e88971aa24 100644 --- a/lib/marin/src/marin/scaling_laws/recipe.py +++ b/lib/marin/src/marin/scaling_laws/recipe.py @@ -14,295 +14,54 @@ """Scaling recipes: model-specific hyperparameter bundles for scaling law experiments. -A ScalingRecipe encapsulates model-specific decisions for scaling experiments: +A ScalingRecipe defines the interface for scaling experiments. Concrete implementations +provide model-specific decisions for: - Architecture formula (how to compute architecture from target param count) -- Architecture ratios (MLP width, head size) - Model config building (returns LlamaConfig or subclass) -- Learning rate and optimizer hyperparameters -- Search bounds and constraints for isoflop sweeps +- Optimizer config building - Candidate generation for isoflop sweeps """ -import math -import os from collections.abc import Iterator, Sequence -from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol -from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig from levanter.models.llama import LlamaConfig -from levanter.models.qwen import Qwen3Config -from levanter.optim.cautious import CautiousConfig from levanter.optim.config import OptimizerConfig if TYPE_CHECKING: from marin.scaling_laws.isoflop_analysis import CandidateConfig, IsoFlopTrainArgs -# TODO: LlamaConfig is used as our "abstract" model config base class. -# All model configs we use (Qwen3Config, etc.) inherit from LlamaConfig -# and provide flops_per_token() for FLOP calculations. - # Default constants DEFAULT_SEQ_LEN = 4096 DEFAULT_STEPS_PER_RUN = 2**16 # Reference step count for hyperparameter tuning DEFAULT_FLOP_TOLERANCE = 0.01 # Relative error tolerance for FLOP budget -DEFAULT_TOKENIZER = "stanford-crfm/marin-tokenizer" - - -def _round_to_power_of_two(x: float) -> int: - """Round x UP to the nearest power of 2.""" - if x <= 1: - return 1 - return 2 ** math.ceil(math.log2(x)) -@dataclass(frozen=True) -class ScalingRecipe: - """A named set of hyperparameters for scaling law experiments. +class ScalingRecipe(Protocol): + """Protocol defining the interface for scaling law recipes. - The recipe encapsulates ALL model-specific decisions: - - Architecture formula (num_layers from hidden_size) - - Architecture ratios (MLP width, head size) - - Learning rate scaling formula - - Beta2 scaling formula (for Adam) - - Optimizer hyperparameters (weight decay, warmup, etc.) - - Search bounds and constraints for isoflop sweeps - - Candidate generation + Concrete implementations (e.g., Marin2025Recipe) should implement all methods + with their specific hyperparameters and formulas. 
""" name: str """Name identifying this recipe (e.g., 'marin-2025').""" - # --- Learning rate scaling --- - # lr = lr_constant * sqrt(batch_size) / hidden_dim - lr_constant: float = 0.33 - """Constant for learning rate calculation.""" - - # --- Beta2 scaling for Adam --- - # beta2 = beta2_base ** (batch_size / beta2_batch_divisor) - # Reference: https://arxiv.org/pdf/2507.07101 - beta2_base: float = 0.98 - """Base for beta2 exponential scaling.""" - - beta2_batch_divisor: float = 128 - """Divisor for beta2 batch size scaling.""" - - # --- Optimizer hyperparameters --- - weight_decay: float = 0.1 - min_lr_ratio: float = 0.0 - warmup: float = 0.1 - beta1: float = 0.95 - epsilon: float = 1e-15 - max_grad_norm: float = 1.0 - lr_schedule: str = "linear" - decay: float = 0.2 - - # --- Architecture ratios --- - mlp_ratio: int = 4 - """MLP intermediate_dim = hidden_dim * mlp_ratio.""" - - hidden_head_ratio: int = 128 - """num_heads = hidden_dim / hidden_head_ratio.""" - - # --- Architecture formula for depth-to-width scaling --- - # num_layers = round( - # hidden_size - # / ( - # base_hidden_layer_ratio - # + (log2(hidden_size) * layer_scaling_factor) - # - layer_formula_offset - # ) - # ) - base_hidden_layer_ratio: int = 64 - """Base divisor for depth-width formula.""" - - layer_scaling_factor: float = 4.0 - """Multiplier for log2(hidden_size) in depth formula.""" - - layer_formula_offset: int = 9 - """Offset (typically min_hidden_pow) in depth formula.""" - - # --- Constraints --- - max_learning_rate: float = 0.01 - """Maximum allowed learning rate (configs with higher LR are rejected).""" - - min_batch_size: int = 8 - """Minimum allowed batch size (configs with smaller batch are rejected).""" - - # --- Search bounds for isoflop sweeps --- - min_hidden_pow: int = 9 - """Minimum hidden_size as power of 2 (2^9 = 512).""" - - max_hidden_pow: int = 12 - """Maximum hidden_size as power of 2 (2^12 = 4096).""" - - small_budget_step_size: int = 128 - """Step size for hidden_size search at smaller budgets.""" - - large_budget_step_size: int = 256 - """Step size for hidden_size search at larger budgets.""" - - budget_step_threshold: float = 9e18 - """Budget threshold for switching step sizes.""" - - # --- Hyperparameter formulas --- - - def compute_learning_rate(self, batch_size: int, hidden_dim: int) -> float: - """Compute learning rate from batch size and hidden dim.""" - return (self.lr_constant * math.sqrt(batch_size)) / hidden_dim - - def compute_beta2(self, batch_size: int) -> float: - """Compute beta2 from batch size.""" - return self.beta2_base ** (batch_size / self.beta2_batch_divisor) - def compute_num_layers(self, hidden_size: int) -> int: - """Compute number of layers from hidden size using the depth-width formula.""" - hs_pow = math.log2(hidden_size) - return round( - hidden_size - / (self.base_hidden_layer_ratio + (hs_pow * self.layer_scaling_factor) - self.layer_formula_offset) - ) - - def get_step_size(self, budget: float) -> int: - """Get hidden_size search step size based on budget.""" - if budget > self.budget_step_threshold: - return self.large_budget_step_size - return self.small_budget_step_size - - # --- Parameter count estimation --- - - def compute_params_for_hidden_size(self, hidden_size: int, vocab_size: int) -> int: - """Compute approximate parameter count for a given hidden size. - - This uses the standard transformer parameter formula for Llama/Qwen architectures. 
- """ - num_layers = self.compute_num_layers(hidden_size) - intermediate_dim = hidden_size * self.mlp_ratio - n_heads = max(1, hidden_size // self.hidden_head_ratio) - head_size = hidden_size // n_heads - - # Embeddings - embed_params = vocab_size * hidden_size * 2 # input + output embeddings - - # Per-layer params: attention + mlp + layer norms - q_proj = hidden_size * head_size * n_heads - kv_proj = 2 * hidden_size * head_size * n_heads # K and V - o_proj = head_size * n_heads * hidden_size - attn_params = q_proj + kv_proj + o_proj - - mlp_params = 3 * hidden_size * intermediate_dim # gate, up, down - norm_params = 2 * hidden_size # 2 layer norms per layer - - layer_params = attn_params + mlp_params + norm_params - total_layer_params = num_layers * layer_params - - # Final layer norm - final_norm = hidden_size - - return embed_params + total_layer_params + final_norm + """Compute number of layers from hidden size using the recipe's depth-width formula.""" + ... def hidden_size_for_params(self, target_params: int, vocab_size: int) -> int: - """Find the hidden size that gives approximately target_params. - - Uses binary search over valid hidden sizes. - """ - min_hidden = 2**self.min_hidden_pow - max_hidden = 2**self.max_hidden_pow - - best_hidden = min_hidden - best_diff = abs(self.compute_params_for_hidden_size(min_hidden, vocab_size) - target_params) - - # Search in steps of 64 for efficiency - for hidden_size in range(min_hidden, max_hidden + 1, 64): - params = self.compute_params_for_hidden_size(hidden_size, vocab_size) - diff = abs(params - target_params) - if diff < best_diff: - best_diff = diff - best_hidden = hidden_size - - return best_hidden - - # --- Model config building --- + """Find the hidden size that gives approximately target_params.""" + ... def build_model_config(self, target_params: int, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: - """Build a model config for a target parameter count. - - The recipe determines the architecture (hidden_size, num_layers, etc.) - that achieves approximately target_params parameters. - - Args: - target_params: Target parameter count. - vocab_size: Vocabulary size. - seq_len: Maximum sequence length. - - Returns: - A LlamaConfig (or subclass) with architecture determined by this recipe. - """ - hidden_size = self.hidden_size_for_params(target_params, vocab_size) - num_layers = self.compute_num_layers(hidden_size) - intermediate_dim = hidden_size * self.mlp_ratio - n_heads = max(1, hidden_size // self.hidden_head_ratio) - - return Qwen3Config( - hidden_dim=hidden_size, - intermediate_dim=intermediate_dim, - num_layers=num_layers, - num_heads=n_heads, - num_kv_heads=n_heads, - max_seq_len=seq_len, - rope=Llama3RotaryEmbeddingsConfig(), - ) - - def _build_model_config_from_hidden_size(self, hidden_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: - """Internal: build model config from hidden_size directly. - - Used during candidate generation when we're iterating over hidden sizes. - """ - num_layers = self.compute_num_layers(hidden_size) - intermediate_dim = hidden_size * self.mlp_ratio - n_heads = max(1, hidden_size // self.hidden_head_ratio) - - return Qwen3Config( - hidden_dim=hidden_size, - intermediate_dim=intermediate_dim, - num_layers=num_layers, - num_heads=n_heads, - num_kv_heads=n_heads, - max_seq_len=seq_len, - rope=Llama3RotaryEmbeddingsConfig(), - ) + """Build a model config for a target parameter count.""" + ... 
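# --- [editor's note] ---------------------------------------------------------
# A minimal sketch (not part of the patch) of the three closed-form rules the
# Marin-2025 recipe uses: the depth-width formula, the learning-rate rule
# lr = lr_constant * sqrt(B) / d, and the Adam beta2 rule
# beta2 = beta2_base ** (B / beta2_batch_divisor). The constants (0.33, 0.98,
# 128, 64, 4.0, 9) are the defaults from this patch series; the helper names
# below are invented for illustration only.
import math

def sketch_num_layers(hidden_size: int) -> int:
    # num_layers = round(d / (64 + 4.0 * log2(d) - 9))
    return round(hidden_size / (64 + 4.0 * math.log2(hidden_size) - 9))

def sketch_lr(batch_size: int, hidden_dim: int) -> float:
    return 0.33 * math.sqrt(batch_size) / hidden_dim

def sketch_beta2(batch_size: int) -> float:
    return 0.98 ** (batch_size / 128)

# d=512 -> 6 layers, d=1024 -> 11, d=2048 -> 21, d=4096 -> 40
assert [sketch_num_layers(d) for d in (512, 1024, 2048, 4096)] == [6, 11, 21, 40]
# B=256, d=1024 -> lr ~= 5.2e-3; beta2 = 0.98**2 ~= 0.960
print(sketch_lr(256, 1024), sketch_beta2(256))
# -----------------------------------------------------------------------------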
def build_optimizer_config(self, candidate: "CandidateConfig", vocab_size: int) -> OptimizerConfig: - """Build optimizer config for a candidate. - - Computes learning rate and beta2 from the candidate's batch_size and target_params. - - Args: - candidate: Model-agnostic candidate config. - vocab_size: Vocabulary size (needed to determine hidden_size). - - Returns: - An OptimizerConfig with settings from this recipe. - """ - hidden_size = self.hidden_size_for_params(candidate.target_params, vocab_size) - learning_rate = self.compute_learning_rate(candidate.batch_size, hidden_size) - beta2 = self.compute_beta2(candidate.batch_size) - - return CautiousConfig( - learning_rate=learning_rate, - weight_decay=self.weight_decay, - min_lr_ratio=self.min_lr_ratio, - warmup=self.warmup, - beta1=self.beta1, - beta2=beta2, - epsilon=self.epsilon, - max_grad_norm=self.max_grad_norm, - adamc_weight_decay=True, - lr_schedule=self.lr_schedule, - decay=self.decay, - ) - - # --- Candidate generation (model-specific search) --- + """Build optimizer config for a candidate.""" + ... def candidate_configs( self, @@ -312,65 +71,8 @@ def candidate_configs( steps_per_run: int = DEFAULT_STEPS_PER_RUN, flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> "Iterator[CandidateConfig]": - """Yield model-agnostic candidate configurations within the FLOP budget. - - This method encapsulates the model-specific search logic internally but - returns model-agnostic CandidateConfig objects containing only: - batch_size, train_steps, tokens, target_params, flops_budget. - - The caller uses recipe.build_model_config() and recipe.build_optimizer_config() - to convert these to model-specific configs. - - Args: - budget: Target FLOP budget. - vocab_size: Vocabulary size for the tokenizer. - seq_len: Sequence length for training. - steps_per_run: Reference step count for FLOP budget calculation. - flop_tolerance: Tolerance for matching FLOP budget (relative error). - - Yields: - Model-agnostic CandidateConfig objects for each valid configuration. 
- """ - # Import here to avoid circular dependency - from marin.scaling_laws.isoflop_analysis import CandidateConfig, solve_for_batch_size, solve_for_train_steps - - step_size = self.get_step_size(budget) - min_hidden = 2**self.min_hidden_pow - max_hidden = 2**self.max_hidden_pow - - for hidden_size in range(min_hidden, max_hidden + 1, step_size): - model_config = self._build_model_config_from_hidden_size(hidden_size, seq_len) - - batch_exact = solve_for_batch_size(model_config, vocab_size, budget, steps_per_run, seq_len) - batch_size = _round_to_power_of_two(batch_exact) - - # Apply LR constraint - lr = self.compute_learning_rate(batch_size, hidden_size) - while lr > self.max_learning_rate: - batch_size //= 2 - lr = self.compute_learning_rate(batch_size, hidden_size) - - if batch_size < self.min_batch_size: - continue - - train_steps = round(solve_for_train_steps(model_config, vocab_size, budget, batch_size, seq_len)) - - # Verify we hit the budget within tolerance - # Training FLOPs = 3 * flops_per_token * batch * steps * seq_len - achieved_flops = 3 * model_config.flops_per_token(vocab_size, seq_len) * batch_size * train_steps * seq_len - if abs(achieved_flops - budget) / budget > flop_tolerance: - continue - - tokens = batch_size * train_steps * seq_len - target_params = self.compute_params_for_hidden_size(hidden_size, vocab_size) - - yield CandidateConfig( - batch_size=batch_size, - train_steps=train_steps, - tokens=tokens, - target_params=target_params, - flops_budget=budget, - ) + """Yield candidate configurations within the FLOP budget.""" + ... def generate_isoflop_train_args( self, @@ -381,54 +83,8 @@ def generate_isoflop_train_args( steps_per_run: int = DEFAULT_STEPS_PER_RUN, flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> "list[IsoFlopTrainArgs]": - """Generate model-agnostic training arguments for each candidate in an isoflop sweep. - - Returns IsoFlopTrainArgs containing the model-agnostic CandidateConfig plus - naming information. The caller uses recipe.build_model_config() and - recipe.build_optimizer_config() to get model-specific configs. - - Args: - budgets: Sequence of FLOP budgets to generate configs for. - experiment_name: Name suffix for run names (e.g., 'nemo', 'dclm'). - vocab_size: Vocabulary size for the tokenizer. - seq_len: Sequence length for training. - steps_per_run: Reference step count for FLOP budget calculation. - flop_tolerance: Tolerance for matching FLOP budget. - - Returns: - List of IsoFlopTrainArgs, one per candidate config across all budgets. - """ - from marin.scaling_laws.isoflop_analysis import IsoFlopTrainArgs - - results: list[IsoFlopTrainArgs] = [] - - for budget in budgets: - for candidate in self.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance): - run_name = ( - f"isoflop-{budget:.0e}-N{candidate.target_params:.0e}-" f"B{candidate.batch_size}-{experiment_name}" - ) - - tags = ( - f"FLOPs={budget:.1e}", - f"N={candidate.target_params:.1e}", - f"B={candidate.batch_size}", - f"steps={candidate.train_steps}", - f"tokens={candidate.tokens:.1e}", - ) - - # Static output path for checkpoint reuse - output_path = os.path.join("checkpoints", "isoflop", run_name) - - results.append( - IsoFlopTrainArgs( - candidate=candidate, - run_name=run_name, - tags=tags, - output_path=output_path, - ) - ) - - return results + """Generate training arguments for each candidate in an isoflop sweep.""" + ... 
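# --- [editor's note] ---------------------------------------------------------
# A small sketch (not part of the patch) of the FLOP bookkeeping that
# candidate_configs relies on: train steps are solved from the budget via
# steps = C / (3 * flops_per_token * B * seq_len), and the achieved FLOPs are
# re-checked against the budget within flop_tolerance. The flops_per_token
# value below is a stand-in number, not a real model's.
def sketch_solve_for_train_steps(budget: float, flops_per_token: float, batch: int, seq_len: int) -> int:
    return round(budget / (3 * flops_per_token * batch * seq_len))

budget, fpt, batch, seq_len = 3e18, 1.2e9, 256, 4096
steps = sketch_solve_for_train_steps(budget, fpt, batch, seq_len)
achieved = 3 * fpt * batch * steps * seq_len
assert abs(achieved - budget) / budget <= 0.01  # same 1% tolerance as DEFAULT_FLOP_TOLERANCE
tokens = batch * steps * seq_len  # token count used on the isoflop x-axis
# -----------------------------------------------------------------------------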
def predict_optimal_config( self, @@ -440,53 +96,5 @@ def predict_optimal_config( steps_per_run: int = DEFAULT_STEPS_PER_RUN, flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> "CandidateConfig | None": - """Predict optimal training config for a target compute budget using fitted scaling laws. - - This implements IsoFLOP Approach 2 from the Chinchilla paper: - 1. Uses the scaling fit (N* ~ A * C^alpha) to predict optimal tokens for target_flops - 2. Generates candidate configs for the target budget using this recipe - 3. Selects the candidate whose token count is closest to the predicted optimal - - Args: - scaling_fits: Dict of {label: (alpha, A)} from scaling ladder result. - target_flops: Target compute budget in FLOPs. - label: Dataset/experiment label to use for scaling fit. - vocab_size: Vocabulary size. - seq_len: Sequence length for training. - steps_per_run: Reference step count for FLOP budget calculation. - flop_tolerance: Tolerance for matching FLOP budget. - - Returns: - CandidateConfig for the predicted optimal, or None if label not in fits - or no valid candidates found. - """ - import logging - - logger = logging.getLogger(__name__) - - if label not in scaling_fits: - logger.warning(f"Label '{label}' not found in scaling fits") - return None - - alpha, A = scaling_fits[label] - optimal_tokens = A * (target_flops**alpha) - - logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") - - candidates = list(self.candidate_configs(target_flops, vocab_size, seq_len, steps_per_run, flop_tolerance)) - - if not candidates: - logger.warning(f"No valid candidates found for budget {target_flops:.2e}") - return None - - best = min(candidates, key=lambda c: c.tokens - optimal_tokens if c.tokens >= optimal_tokens else float("inf")) - # If all candidates have fewer tokens than optimal, pick the one with the most tokens - if best.tokens < optimal_tokens: - best = max(candidates, key=lambda c: c.tokens) - - logger.info( - f"Selected config: N={best.target_params:.2e}, " - f"B={best.batch_size}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})" - ) - - return best + """Predict optimal training config for a target compute budget using fitted scaling laws.""" + ... diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py index 165df2ab0e..21940a5985 100644 --- a/lib/marin/src/marin/scaling_laws/tpu_utils.py +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -85,7 +85,7 @@ def pick_v5p_type( candidate: "CandidateConfig", vocab_size: int, seq_len: int, - recipe: "ScalingRecipe | None" = None, + recipe: "ScalingRecipe", ) -> str: """Select the smallest TPU v5p slice that fits the model in float32. @@ -93,7 +93,7 @@ def pick_v5p_type( candidate: CandidateConfig with target_params and batch_size. vocab_size: Vocabulary size. seq_len: Sequence length. - recipe: ScalingRecipe to determine architecture. If None, uses default. + recipe: ScalingRecipe to determine architecture. Returns: TPU slice name, e.g., "v5p-8" or "v5p-32". @@ -101,11 +101,6 @@ def pick_v5p_type( Raises: ValueError: If the model is too large for available v5p slices. 
""" - if recipe is None: - from marin.scaling_laws.recipe import ScalingRecipe - - recipe = ScalingRecipe(name="default") - hidden_size = recipe.hidden_size_for_params(candidate.target_params, vocab_size) num_layers = recipe.compute_num_layers(hidden_size) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 90c865b7af..23e5f1f039 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -24,7 +24,6 @@ from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig from levanter.models.qwen import Qwen3Config -from marin.scaling_laws import ScalingRecipe from marin.scaling_laws.isoflop_analysis import ( DEFAULT_SEQ_LEN, MARIN_TOKENIZER_VOCAB_SIZE, @@ -40,6 +39,9 @@ transform_metrics_for_isoflop, ) +# Import the concrete recipe from experiments for testing +from experiments.isoflop_sweep import Marin2025Recipe + # --- FLOP computation tests --- @@ -95,7 +97,7 @@ def test_parse_isoflop_run_name(): def test_candidate_configs_within_tolerance(): """Test that generated configs achieve the target FLOP budget within tolerance.""" - recipe = ScalingRecipe(name="test") + recipe = Marin2025Recipe() budget = 1e19 flop_tolerance = 0.01 seq_len = DEFAULT_SEQ_LEN @@ -147,7 +149,7 @@ def test_generate_isoflop_train_args_snapshot(): This ensures reproducibility of the config generation algorithm. """ - recipe = ScalingRecipe(name="test-snapshot") + recipe = Marin2025Recipe() result = generate_isoflop_train_args( budgets=(3e18,), experiment_name="test-snapshot", From 9cdeebf5978df2454b1df302df0d7cc10ab4ff09 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 8 Jan 2026 10:49:00 -0800 Subject: [PATCH 55/79] Wandb is always available --- .../src/marin/scaling_laws/eval_metrics_reader.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index a8d8abef63..15f0354d88 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -29,16 +29,10 @@ import fsspec import pandas as pd +import wandb from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT -try: - import wandb - - WANDB_AVAILABLE = True -except ImportError: - WANDB_AVAILABLE = False - logger = logging.getLogger(__name__) @@ -69,10 +63,6 @@ def _backfill_metrics_from_wandb( Returns: True if backfill succeeded, False otherwise """ - if not WANDB_AVAILABLE: - logger.warning(f"wandb not available, cannot backfill metrics for {checkpoint_path}") - return False - try: run_id = extract_run_name_from_path(checkpoint_path) logger.info(f"Attempting to backfill metrics for run_id: {run_id}") From 4421d2afea443e20c6647b87a5852d99fc664ddb Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 8 Jan 2026 12:08:01 -0800 Subject: [PATCH 56/79] Claude Code got out of hand here --- experiments/isoflop_sweep.py | 102 +++------------- lib/marin/src/marin/scaling_laws/__init__.py | 4 - .../marin/scaling_laws/isoflop_analysis.py | 111 ++++++++---------- lib/marin/src/marin/scaling_laws/recipe.py | 50 ++------ lib/marin/src/marin/scaling_laws/tpu_utils.py | 56 +-------- 5 files changed, 83 insertions(+), 240 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 826b13f48d..e9aef2efe0 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -19,8 +19,7 @@ """ import math -import os -from collections.abc import Iterator, Sequence +from collections.abc import 
Iterator from dataclasses import dataclass, replace from levanter.data.text import LMMixtureDatasetConfig @@ -47,7 +46,6 @@ DEFAULT_SEQ_LEN, DEFAULT_STEPS_PER_RUN, CandidateConfig, - IsoFlopTrainArgs, ScalingRecipe, generate_isoflop_train_args, pick_v5p_type, @@ -203,6 +201,24 @@ def _build_model_config_from_hidden_size(self, hidden_size: int, seq_len: int = rope=Llama3RotaryEmbeddingsConfig(), ) + def estimate_memory_bytes( + self, + model_config: LlamaConfig, + batch_size: int, + vocab_size: int, + optim_mult: int = 3, + dtype_size: int = 4, + fudge_factor: float = 2.0, + ) -> int: + """Estimate float32 memory usage in bytes for training.""" + param_count = self._compute_params_for_hidden_size(model_config.hidden_dim, vocab_size) + param_bytes = param_count * optim_mult * dtype_size + act_bytes = (batch_size * model_config.max_seq_len) * ( + (model_config.hidden_dim * model_config.num_layers) + vocab_size * fudge_factor + ) + total_bytes = param_bytes + act_bytes + return int(total_bytes * fudge_factor) + def build_optimizer_config(self, candidate: CandidateConfig, vocab_size: int) -> OptimizerConfig: """Build optimizer config for a candidate.""" hidden_size = self.hidden_size_for_params(candidate.target_params, vocab_size) @@ -267,86 +283,6 @@ def candidate_configs( flops_budget=budget, ) - def generate_isoflop_train_args( - self, - budgets: Sequence[float], - experiment_name: str, - vocab_size: int, - seq_len: int = DEFAULT_SEQ_LEN, - steps_per_run: int = DEFAULT_STEPS_PER_RUN, - flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, - ) -> list[IsoFlopTrainArgs]: - """Generate training arguments for each candidate in an isoflop sweep.""" - results: list[IsoFlopTrainArgs] = [] - - for budget in budgets: - for candidate in self.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance): - run_name = ( - f"isoflop-{budget:.0e}-N{candidate.target_params:.0e}-" f"B{candidate.batch_size}-{experiment_name}" - ) - - tags = ( - f"FLOPs={budget:.1e}", - f"N={candidate.target_params:.1e}", - f"B={candidate.batch_size}", - f"steps={candidate.train_steps}", - f"tokens={candidate.tokens:.1e}", - ) - - output_path = os.path.join("checkpoints", "isoflop", run_name) - - results.append( - IsoFlopTrainArgs( - candidate=candidate, - run_name=run_name, - tags=tags, - output_path=output_path, - ) - ) - - return results - - def predict_optimal_config( - self, - scaling_fits: dict[str, tuple[float, float]], - target_flops: float, - label: str, - vocab_size: int, - seq_len: int = DEFAULT_SEQ_LEN, - steps_per_run: int = DEFAULT_STEPS_PER_RUN, - flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, - ) -> CandidateConfig | None: - """Predict optimal training config for a target compute budget using fitted scaling laws.""" - import logging - - logger = logging.getLogger(__name__) - - if label not in scaling_fits: - logger.warning(f"Label '{label}' not found in scaling fits") - return None - - alpha, A = scaling_fits[label] - optimal_tokens = A * (target_flops**alpha) - - logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") - - candidates = list(self.candidate_configs(target_flops, vocab_size, seq_len, steps_per_run, flop_tolerance)) - - if not candidates: - logger.warning(f"No valid candidates found for budget {target_flops:.2e}") - return None - - best = min(candidates, key=lambda c: c.tokens - optimal_tokens if c.tokens >= optimal_tokens else float("inf")) - if best.tokens < optimal_tokens: - best = max(candidates, key=lambda c: c.tokens) - - logger.info( - 
f"Selected config: N={best.target_params:.2e}, " - f"B={best.batch_size}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})" - ) - - return best - MARIN_2025_RECIPE = Marin2025Recipe() """Default Marin scaling recipe.""" diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index eb6222ae34..ec3735a551 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -28,7 +28,6 @@ ScalingFit, candidate_configs, compute_training_flops, - compute_transformer_params, fit_scaling_laws, generate_isoflop_train_args, predict_optimal_config, @@ -39,7 +38,6 @@ solve_for_train_steps, ) from marin.scaling_laws.tpu_utils import ( - estimate_memory_bytes, pick_v5p_type, ) from marin.scaling_laws.recipe import ( @@ -77,10 +75,8 @@ # Functions "candidate_configs", "compute_training_flops", - "compute_transformer_params", "create_isoflop_plot", "create_scaling_plot", - "estimate_memory_bytes", "fit_scaling_laws", "generate_isoflop_train_args", "pick_v5p_type", diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index d63741a5ee..2b8e0fdcf3 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -314,56 +314,6 @@ def solve_for_train_steps( return target_flops / (3 * flops_per_token * batch_size * seq_len) -def compute_transformer_params( - hidden_dim: int, - intermediate_dim: int, - num_layers: int, - num_heads: int, - num_kv_heads: int, - vocab_size: int, - tie_embeddings: bool = False, -) -> int: - """Compute parameter count for a standard transformer (Llama/Qwen architecture). - - This matches the formula used in Levanter's LlamaConfig.total_trainable_params(), - allowing parameter estimation without constructing a model config. - - Args: - hidden_dim: Model hidden dimension. - intermediate_dim: MLP intermediate dimension. - num_layers: Number of transformer layers. - num_heads: Number of attention heads. - num_kv_heads: Number of key-value heads (for GQA). - vocab_size: Vocabulary size. - tie_embeddings: Whether embeddings are tied (default False). - - Returns: - Total parameter count. - """ - token_embedding = vocab_size * hidden_dim - head_size = hidden_dim // num_heads - - # Attention: Q, K, V projections + output projection - q_proj = hidden_dim * head_size * num_heads - kv_proj = 2 * hidden_dim * head_size * num_kv_heads - o_proj = head_size * num_heads * hidden_dim - attn = q_proj + kv_proj + o_proj - - # MLP: gate, up, down projections (SwiGLU uses 3 matrices) - mlp = 3 * hidden_dim * intermediate_dim - - # Per-layer: attention + mlp + 2 RMSNorm - transformer_layer = attn + mlp + 2 * hidden_dim - - # Full transformer: layers + final RMSNorm - transformer = num_layers * transformer_layer + hidden_dim - - # LM head (separate unless tied) - lm_head = 0 if tie_embeddings else token_embedding - - return transformer + token_embedding + lm_head - - def candidate_configs( budget: float, vocab_size: int, @@ -433,7 +383,6 @@ def generate_isoflop_train_args( ) -> list[IsoFlopTrainArgs]: """Generate model-agnostic training arguments for each candidate in an isoflop sweep. - This is a convenience function that delegates to recipe.generate_isoflop_train_args(). Returns IsoFlopTrainArgs containing model-agnostic CandidateConfig objects. Use recipe.build_model_config() and recipe.build_optimizer_config() to get model-specific configs. 
@@ -465,9 +414,30 @@ def generate_isoflop_train_args( ... model_config = recipe.build_model_config(args.candidate.target_params, vocab_size) ... optimizer_config = recipe.build_optimizer_config(args.candidate, vocab_size) """ - return recipe.generate_isoflop_train_args( - budgets, experiment_name, vocab_size, seq_len, steps_per_run, flop_tolerance - ) + results: list[IsoFlopTrainArgs] = [] + + for budget in budgets: + for candidate in recipe.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance): + run_name = f"isoflop-{budget:.0e}-N{candidate.target_params:.0e}-B{candidate.batch_size}-{experiment_name}" + tags = ( + f"FLOPs={budget:.1e}", + f"N={candidate.target_params:.1e}", + f"B={candidate.batch_size}", + f"steps={candidate.train_steps}", + f"tokens={candidate.tokens:.1e}", + ) + output_path = os.path.join("checkpoints", "isoflop", run_name) + + results.append( + IsoFlopTrainArgs( + candidate=candidate, + run_name=run_name, + tags=tags, + output_path=output_path, + ) + ) + + return results # ---------------- Helpers ---------------- @@ -743,10 +713,6 @@ def predict_optimal_config( ) -> CandidateConfig | None: """Predict optimal training config for a target compute budget using fitted scaling laws. - This is a convenience function that delegates to recipe.predict_optimal_config(). - The recipe encapsulates all model-specific decisions, while this function provides - backward compatibility. - This implements IsoFLOP Approach 2 from the Chinchilla paper: 1. D_opt (optimal tokens) is found empirically at each compute budget by fitting parabolas to actual loss values and finding the minimum. @@ -768,12 +734,33 @@ def predict_optimal_config( CandidateConfig for the predicted optimal, or None if label not in fits or no valid candidates found. 
""" - # Convert ScalingFit NamedTuples to plain tuples for recipe method - fits_as_tuples = {k: (v.alpha, v.A) for k, v in scaling_fits.items()} - return recipe.predict_optimal_config( - fits_as_tuples, target_flops, label, vocab_size, seq_len, steps_per_run, flop_tolerance + if label not in scaling_fits: + logger.warning(f"Label '{label}' not found in scaling fits") + return None + + alpha, A = scaling_fits[label] + optimal_tokens = A * (target_flops**alpha) + + logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") + + candidates = list(recipe.candidate_configs(target_flops, vocab_size, seq_len, steps_per_run, flop_tolerance)) + + if not candidates: + logger.warning(f"No valid candidates found for budget {target_flops:.2e}") + return None + + # Find candidate with tokens >= optimal_tokens, closest to optimal + best = min(candidates, key=lambda c: c.tokens - optimal_tokens if c.tokens >= optimal_tokens else float("inf")) + if best.tokens < optimal_tokens: + best = max(candidates, key=lambda c: c.tokens) + + logger.info( + f"Selected config: N={best.target_params:.2e}, " + f"B={best.batch_size}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})" ) + return best + def predict_optimal_configs_for_budgets( scaling_fits: dict[str, ScalingFit], diff --git a/lib/marin/src/marin/scaling_laws/recipe.py b/lib/marin/src/marin/scaling_laws/recipe.py index e88971aa24..2ed897b173 100644 --- a/lib/marin/src/marin/scaling_laws/recipe.py +++ b/lib/marin/src/marin/scaling_laws/recipe.py @@ -20,16 +20,19 @@ - Model config building (returns LlamaConfig or subclass) - Optimizer config building - Candidate generation for isoflop sweeps + +Orchestration logic (generating train args, predicting optimal configs) lives in +the library functions in isoflop_analysis.py, not in recipes. """ -from collections.abc import Iterator, Sequence +from collections.abc import Iterator from typing import TYPE_CHECKING, Protocol from levanter.models.llama import LlamaConfig from levanter.optim.config import OptimizerConfig if TYPE_CHECKING: - from marin.scaling_laws.isoflop_analysis import CandidateConfig, IsoFlopTrainArgs + from marin.scaling_laws.isoflop_analysis import CandidateConfig # Default constants DEFAULT_SEQ_LEN = 4096 @@ -40,25 +43,23 @@ class ScalingRecipe(Protocol): """Protocol defining the interface for scaling law recipes. - Concrete implementations (e.g., Marin2025Recipe) should implement all methods - with their specific hyperparameters and formulas. + Concrete implementations (e.g., Marin2025Recipe) should implement these + model-specific methods. Orchestration logic (generating training args, + predicting optimal configs) is handled by library functions that use + these core methods. """ name: str """Name identifying this recipe (e.g., 'marin-2025').""" - def compute_num_layers(self, hidden_size: int) -> int: - """Compute number of layers from hidden size using the recipe's depth-width formula.""" - ... - - def hidden_size_for_params(self, target_params: int, vocab_size: int) -> int: - """Find the hidden size that gives approximately target_params.""" - ... - def build_model_config(self, target_params: int, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: """Build a model config for a target parameter count.""" ... + def estimate_memory_bytes(self, model_config: LlamaConfig, batch_size: int, vocab_size: int) -> int: + """Estimate memory usage in bytes for training with this model config.""" + ... 
+ def build_optimizer_config(self, candidate: "CandidateConfig", vocab_size: int) -> OptimizerConfig: """Build optimizer config for a candidate.""" ... @@ -73,28 +74,3 @@ def candidate_configs( ) -> "Iterator[CandidateConfig]": """Yield candidate configurations within the FLOP budget.""" ... - - def generate_isoflop_train_args( - self, - budgets: Sequence[float], - experiment_name: str, - vocab_size: int, - seq_len: int = DEFAULT_SEQ_LEN, - steps_per_run: int = DEFAULT_STEPS_PER_RUN, - flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, - ) -> "list[IsoFlopTrainArgs]": - """Generate training arguments for each candidate in an isoflop sweep.""" - ... - - def predict_optimal_config( - self, - scaling_fits: "dict[str, tuple[float, float]]", - target_flops: float, - label: str, - vocab_size: int, - seq_len: int = DEFAULT_SEQ_LEN, - steps_per_run: int = DEFAULT_STEPS_PER_RUN, - flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, - ) -> "CandidateConfig | None": - """Predict optimal training config for a target compute budget using fitted scaling laws.""" - ... diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py index 21940a5985..b176d45b75 100644 --- a/lib/marin/src/marin/scaling_laws/tpu_utils.py +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -38,49 +38,6 @@ """Available TPU v5p core configurations (slice sizes).""" -def estimate_memory_bytes( - param_count: int, - hidden_dim: int, - num_layers: int, - batch: int, - seq_len: int, - vocab: int, - optim_mult: int = 3, - dtype_size: int = 4, - fudge_factor: float = 2, -) -> int: - """Estimate float32 memory usage (in bytes) for one training step. - - This is a conservative estimate for LLaMA-style architectures with - Adam optimizer. The fudge_factor provides a safety margin for - additional memory overhead not captured in the simple model. - - Args: - param_count: Number of model parameters. - hidden_dim: Model hidden size. - num_layers: Number of Transformer layers. - batch: Training batch size. - seq_len: Sequence length. - vocab: Vocabulary size. - optim_mult: Optimizer memory multiplier (default 3 for Adam with - momentum and variance states). - dtype_size: Bytes per float (default 4 for float32). - fudge_factor: Safety margin multiplier (default 2x). - - Returns: - Estimated total memory in bytes. - - Note: - This assumes a LLaMA-style architecture with Adam optimizer in float32. - Actual memory usage may vary based on specific model architecture, - optimizer choice, and JAX/XLA memory optimizations. - """ - param_bytes = param_count * optim_mult * dtype_size - act_bytes = (batch * seq_len) * ((hidden_dim * num_layers) + vocab * fudge_factor) - total_bytes = param_bytes + act_bytes - return int(total_bytes * fudge_factor) - - def pick_v5p_type( candidate: "CandidateConfig", vocab_size: int, @@ -101,17 +58,8 @@ def pick_v5p_type( Raises: ValueError: If the model is too large for available v5p slices. 
""" - hidden_size = recipe.hidden_size_for_params(candidate.target_params, vocab_size) - num_layers = recipe.compute_num_layers(hidden_size) - - need_bytes = estimate_memory_bytes( - candidate.target_params, - hidden_size, - num_layers, - candidate.batch_size, - seq_len, - vocab_size, - ) + model_config = recipe.build_model_config(candidate.target_params, vocab_size, seq_len) + need_bytes = recipe.estimate_memory_bytes(model_config, candidate.batch_size, vocab_size) chip_bytes = HBM_PER_CHIP_GIB * 1024**3 chips = math.ceil(need_bytes / chip_bytes) cores_req = chips * CORES_PER_CHIP From 569e301279f66e8af7876e6660077f178889d488 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 8 Jan 2026 12:21:40 -0800 Subject: [PATCH 57/79] Note Differences --- lib/marin/src/marin/scaling_laws/recipe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/marin/src/marin/scaling_laws/recipe.py b/lib/marin/src/marin/scaling_laws/recipe.py index 2ed897b173..7afea93c39 100644 --- a/lib/marin/src/marin/scaling_laws/recipe.py +++ b/lib/marin/src/marin/scaling_laws/recipe.py @@ -53,7 +53,11 @@ class ScalingRecipe(Protocol): """Name identifying this recipe (e.g., 'marin-2025').""" def build_model_config(self, target_params: int, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: - """Build a model config for a target parameter count.""" + """Build a model config for a target parameter count. + + TODO: LlamaConfig is currently our most generic config type, but this + couples recipes to Llama-family architectures. Generalize when needed. + """ ... def estimate_memory_bytes(self, model_config: LlamaConfig, batch_size: int, vocab_size: int) -> int: From 015029214e2c807ec968d1b17c976e026da15c16 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 8 Jan 2026 13:56:09 -0800 Subject: [PATCH 58/79] Remove all 6ND --- .../src/marin/scaling_laws/isoflop_analysis.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 2b8e0fdcf3..2128239558 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -564,13 +564,22 @@ def fit_scaling_laws( idx = (sub.tokens - D_star).abs().argmin() nearest_row = sub.iloc[idx] + # Require params to be present - the 6ND approximation is inaccurate for small models + params = nearest_row.get("params") + if params is None or pd.isna(params): + logger.warning( + f"Missing params for {lab} at {C:.1e} FLOPs - skipping. " + "Ensure runs log parameter_count or have full model config." 
+ ) + continue + minima_records.append( MinimaRecord( label=lab, flops=float(C), optimal_tokens=D_star, loss_at_optimal=loss_opt, - optimal_params=float(nearest_row.get("params") or C / (6 * D_star)), + optimal_params=float(params), batch_size=int(nearest_row["batch_size"]), ) ) @@ -664,7 +673,7 @@ def transform_metrics_for_isoflop( logger.warning(f"Missing metric {metric_key} for run {run_name}") continue - # Get parameter count from summary + # Get parameter count from summary (required for accurate scaling analysis) params = summary.get(PARAMETER_COUNT_KEY) if params is None or pd.isna(params): params = None From 6ce4376064f1d4cf132378d23a651b394db1dd88 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 8 Jan 2026 14:48:02 -0800 Subject: [PATCH 59/79] Simplify --- .../exp2166_scaling_ladder_analysis.py | 8 +++- .../src/marin/scaling_laws/scaling_ladder.py | 45 ++----------------- 2 files changed, 9 insertions(+), 44 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 98d4fcbfaa..50becd8826 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -26,6 +26,7 @@ from experiments.defaults import default_validation_sets from experiments.isoflop_sweep import MARIN_2025_RECIPE, MARIN_SCALING_SUITES, nemotron_mix from marin.execution.executor import ExecutorStep, executor_main, output_path_of +from marin.processing.tokenize import add_validation_sets_to_mixture from marin.scaling_laws import ( IsoFlopAnalysisConfig, ScalingLadderRungConfig, @@ -40,6 +41,10 @@ TARGET_BUDGETS: list[float] = [1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20] EXPERIMENT_NAME = "exp2166-scaling-ladder-nemotron-validation" LABEL = "nemo-wider-depth-adapt" +TOKENIZER = "stanford-crfm/marin-tokenizer" + +# Add validation sets to the training mixture +nemotron_mix_with_validation = add_validation_sets_to_mixture(nemotron_mix, default_validation_sets(tokenizer=TOKENIZER)) # --- Step 1: IsoFLOP Analysis --- # Creates scaling law fits from the training runs @@ -64,10 +69,9 @@ analysis_output_path=output_path_of(analysis_step), target_budget=budget, label=LABEL, - tokenized=nemotron_mix, + tokenized=nemotron_mix_with_validation, output_path=f"checkpoints/{EXPERIMENT_NAME}-optimal-{budget:.0e}", recipe=MARIN_2025_RECIPE, - validation_sets=default_validation_sets(tokenizer="stanford-crfm/marin-tokenizer"), ), ) optimal_runs.append(step) diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index bbadb88ffc..42988c9479 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -36,7 +36,6 @@ from levanter.utils.mesh import MeshConfig from marin.processing.tokenize import get_vocab_size_for_tokenizer -from marin.processing.tokenize.data_configs import add_validation_sets_to_mixture, lm_data_config from marin.scaling_laws.isoflop_analysis import ( ScalingFit, predict_optimal_config, @@ -48,39 +47,6 @@ logger = logging.getLogger(__name__) -def _prepare_data_config( - tokenized: str | LMMixtureDatasetConfig, - validation_sets: dict | None = None, -) -> LMMixtureDatasetConfig: - """Prepare a tokenized dataset for training. - - This is a local helper that prepares data configs without depending on - experiment-specific validation sets. Callers should pass validation sets - explicitly if needed. 
- - Args: - tokenized: The tokenized dataset - can be a path string or an - already-configured LMMixtureDatasetConfig. - validation_sets: Optional dict of validation sets to add. If None, - no validation sets are added. - - Returns: - LMMixtureDatasetConfig ready for training. - """ - if isinstance(tokenized, LMMixtureDatasetConfig): - pretraining_data = tokenized - if validation_sets: - pretraining_data = add_validation_sets_to_mixture(pretraining_data, validation_sets) - else: - # String path - pretraining_data = lm_data_config( - training_set=tokenized, - validation_sets=validation_sets, - permutation_type="feistel", - ) - return pretraining_data - - @dataclass(frozen=True) class ScalingLadderRungConfig: """Configuration for one rung of the scaling ladder (one compute-optimal training run). @@ -101,8 +67,8 @@ class ScalingLadderRungConfig: label: str """Dataset label to use for scaling fit (e.g., 'nemo', 'comma', 'dclm').""" - tokenized: str | LMMixtureDatasetConfig - """Tokenized dataset for training. Can be a path string or LMMixtureDatasetConfig.""" + tokenized: LMMixtureDatasetConfig + """Tokenized dataset for training (with validation sets already added).""" output_path: str """Where to write training outputs.""" @@ -116,9 +82,6 @@ class ScalingLadderRungConfig: seq_len: int = 4096 """Sequence length for training.""" - validation_sets: dict | None = None - """Optional validation sets to add for eval loss tracking.""" - def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: """Run one rung of the scaling ladder (one compute-optimal training run). @@ -166,10 +129,8 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: optimizer_cfg = config.recipe.build_optimizer_config(candidate, vocab_size) tpu_type = pick_v5p_type(candidate, vocab_size, config.seq_len, config.recipe) - pretraining_data = _prepare_data_config(config.tokenized, config.validation_sets) - train_config = TrainLmConfig( - data=pretraining_data, + data=config.tokenized, trainer=TrainerConfig( tracker=WandbConfig( project="marin", From 9c99d8dbb58d83eb272f5e9e32949365f6ec2c0c Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 8 Jan 2026 14:50:42 -0800 Subject: [PATCH 60/79] Comment Tweak --- lib/marin/src/marin/scaling_laws/recipe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/recipe.py b/lib/marin/src/marin/scaling_laws/recipe.py index 7afea93c39..e22596b2be 100644 --- a/lib/marin/src/marin/scaling_laws/recipe.py +++ b/lib/marin/src/marin/scaling_laws/recipe.py @@ -55,8 +55,9 @@ class ScalingRecipe(Protocol): def build_model_config(self, target_params: int, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: """Build a model config for a target parameter count. - TODO: LlamaConfig is currently our most generic config type, but this - couples recipes to Llama-family architectures. Generalize when needed. + TODO: LlamaConfig is currently our most generic config type, and we + subclass it to other models (e.g. Qwen, OLMo, etc). We should make + a true generic config class eventually. """ ... 
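# --- [editor's note] ---------------------------------------------------------
# A minimal sketch (not part of any patch above) of how predict_optimal_config
# turns a fitted scaling law into a concrete choice: predict the optimal token
# count D* = A * C**alpha for the target budget, then take the candidate whose
# token count is the closest one at or above D* (falling back to the largest
# candidate otherwise). The fit values and candidate list here are made up.
alpha, A = 0.5, 6.0                       # illustrative fit, not a real result
target_flops = 1e19
optimal_tokens = A * target_flops**alpha  # ~1.9e10 tokens for this fake fit

candidate_tokens = [8e9, 1.6e10, 2.4e10, 3.4e10]
at_or_above = [t for t in candidate_tokens if t >= optimal_tokens]
best = min(at_or_above) if at_or_above else max(candidate_tokens)
assert best == 2.4e10
# -----------------------------------------------------------------------------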
From 3f1fa57d604455c61f1feaeaa35bd0a9b707b6a0 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 8 Jan 2026 15:18:22 -0800 Subject: [PATCH 61/79] Move Naming to the Experiment Code --- experiments/isoflop_sweep.py | 29 +++++++++++++++++-- .../marin/scaling_laws/isoflop_analysis.py | 20 ++++--------- tests/test_scaling_laws.py | 1 - 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index e9aef2efe0..9345d52f9f 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -61,6 +61,20 @@ def _round_to_power_of_two(x: float) -> int: return 2 ** math.ceil(math.log2(x)) +def _format_run_name( + budget: float, + hidden_size: int, + num_layers: int, + batch_size: int, + experiment_name: str, +) -> str: + """Format run name using architecture details (hidden size and layers). + + Format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} + """ + return f"isoflop-{budget:.0e}-d{hidden_size}-L{num_layers}-B{batch_size}-{experiment_name}" + + @dataclass(frozen=True) class Marin2025Recipe: """Marin 2025 scaling recipe with all hyperparameters and formulas. @@ -320,7 +334,6 @@ def create_isoflop_sweep_steps( # Library provides the training arguments (model configs, optimizer configs, etc.) train_args_list = generate_isoflop_train_args( budgets=budgets, - experiment_name=experiment_name, vocab_size=vocab_size, recipe=recipe, ) @@ -346,6 +359,16 @@ def create_isoflop_sweep_steps( optimizer_config = recipe.build_optimizer_config(candidate, vocab_size) tpu_type = pick_v5p_type(candidate, vocab_size, seq_len, recipe) + # Use local naming with architecture details for backward compatibility + run_name = _format_run_name( + candidate.flops_budget, + model_config.hidden_dim, + model_config.num_layers, + candidate.batch_size, + experiment_name, + ) + output_path = f"checkpoints/isoflop/{run_name}" + train_cfg = replace( base_train_config, train_batch_size=candidate.batch_size, @@ -357,7 +380,7 @@ def create_isoflop_sweep_steps( # Create training step train_step = default_train( - name=args.run_name, + name=run_name, tokenized=tokenized, model_config=model_config, train_config=train_cfg, @@ -366,7 +389,7 @@ def create_isoflop_sweep_steps( ) # Pin to static output path for checkpoint reuse - train_step = train_step.with_output_path(args.output_path) + train_step = train_step.with_output_path(output_path) train_steps.append(train_step) candidates.append(candidate) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 2128239558..4b8e4a3775 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -144,8 +144,11 @@ class IsoFlopTrainArgs: The ScalingRecipe is responsible for converting these to model-specific configs (model architecture, optimizer hyperparameters). + Naming (run_name, output_path) is intentionally not included here - that's + the responsibility of experiment code which may have its own conventions. 
+ Example: - >>> args = generate_isoflop_train_args(budgets, "my-exp", vocab_size, recipe)[0] + >>> args = generate_isoflop_train_args(budgets, vocab_size, recipe)[0] >>> # Recipe converts candidate to model-specific configs >>> model_config = recipe.build_model_config(args.candidate.target_params, vocab_size) >>> optimizer_config = recipe.build_optimizer_config(args.candidate) @@ -154,15 +157,9 @@ class IsoFlopTrainArgs: candidate: CandidateConfig """Model-agnostic compute allocation (batch_size, train_steps, tokens, target_params).""" - run_name: str - """Name for the training run.""" - tags: tuple[str, ...] """Tags for tracking/filtering runs.""" - output_path: str - """Static output path for checkpoints.""" - # ---------------- Typed Records ---------------- @@ -374,7 +371,6 @@ def _minima_to_candidates( def generate_isoflop_train_args( budgets: Sequence[float], - experiment_name: str, vocab_size: int, recipe: ScalingRecipe, seq_len: int = DEFAULT_SEQ_LEN, @@ -385,11 +381,10 @@ def generate_isoflop_train_args( Returns IsoFlopTrainArgs containing model-agnostic CandidateConfig objects. Use recipe.build_model_config() and recipe.build_optimizer_config() to get - model-specific configs. + model-specific configs. Naming (run_name, output_path) is left to the caller. Args: budgets: Sequence of FLOP budgets to generate configs for. - experiment_name: Name suffix for run names (e.g., 'nemo', 'dclm'). vocab_size: Vocabulary size for the tokenizer. recipe: ScalingRecipe with architecture/hyperparameter settings. seq_len: Sequence length for training. @@ -405,7 +400,6 @@ def generate_isoflop_train_args( >>> # recipe = Marin2025Recipe() >>> train_args = generate_isoflop_train_args( ... budgets=DEFAULT_BUDGETS, - ... experiment_name="my-experiment", ... vocab_size=128256, ... recipe=recipe, ... 
) @@ -418,7 +412,6 @@ def generate_isoflop_train_args( for budget in budgets: for candidate in recipe.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance): - run_name = f"isoflop-{budget:.0e}-N{candidate.target_params:.0e}-B{candidate.batch_size}-{experiment_name}" tags = ( f"FLOPs={budget:.1e}", f"N={candidate.target_params:.1e}", @@ -426,14 +419,11 @@ def generate_isoflop_train_args( f"steps={candidate.train_steps}", f"tokens={candidate.tokens:.1e}", ) - output_path = os.path.join("checkpoints", "isoflop", run_name) results.append( IsoFlopTrainArgs( candidate=candidate, - run_name=run_name, tags=tags, - output_path=output_path, ) ) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 23e5f1f039..fc4b2e8b6e 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -152,7 +152,6 @@ def test_generate_isoflop_train_args_snapshot(): recipe = Marin2025Recipe() result = generate_isoflop_train_args( budgets=(3e18,), - experiment_name="test-snapshot", vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, recipe=recipe, ) From f50b879771154cbbe485cbf236df924451b978eb Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 8 Jan 2026 16:08:03 -0800 Subject: [PATCH 62/79] Legacy Support --- experiments/isoflop_sweep.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 9345d52f9f..a0eba80478 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -41,10 +41,6 @@ from marin.execution.executor import ExecutorStep, InputName, executor_main from marin.processing.tokenize import get_vocab_size_for_tokenizer, lm_mixture_data_config from marin.scaling_laws import ( - DEFAULT_BUDGETS, - DEFAULT_FLOP_TOLERANCE, - DEFAULT_SEQ_LEN, - DEFAULT_STEPS_PER_RUN, CandidateConfig, ScalingRecipe, generate_isoflop_train_args, @@ -53,6 +49,12 @@ solve_for_train_steps, ) +DEFAULT_BUDGETS: tuple[float, ...] = (1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20) +LEGACY_BUDGETS: tuple[float, ...] 
= (3e18, 9e18, 1.8e19, 3e19, 9e19, 1.8e20, 3e20) +DEFAULT_SEQ_LEN: int = 4096 +DEFAULT_STEPS_PER_RUN: int = 2**16 +DEFAULT_FLOP_TOLERANCE: float = 0.01 + def _round_to_power_of_two(x: float) -> int: """Round x UP to the nearest power of 2.""" @@ -438,26 +440,31 @@ def create_isoflop_sweep_steps( tokenized=nemotron_mix, experiment_name="nemo-wider-depth-adapt", recipe=MARIN_2025_RECIPE, + budgets=LEGACY_BUDGETS, ), "common_pile": create_isoflop_sweep_steps( tokenized=comma_main_mixture(permutation_type="linear"), experiment_name="comma-mix", recipe=MARIN_2025_RECIPE, + budgets=LEGACY_BUDGETS, ), "common_pile_feistel": create_isoflop_sweep_steps( tokenized=comma_main_mixture(permutation_type="feistel"), experiment_name="comma-mix-feistel", recipe=MARIN_2025_RECIPE, + budgets=LEGACY_BUDGETS, ), "dclm-default": create_isoflop_sweep_steps( tokenized=dclm_mix, experiment_name="dclm-default", recipe=MARIN_2025_RECIPE, + budgets=LEGACY_BUDGETS, ), "dolma3_mix_150b": create_isoflop_sweep_steps( tokenized=dolma3_mix, experiment_name="dolma3-mix-150b-1025", recipe=MARIN_2025_RECIPE, + budgets=LEGACY_BUDGETS, ), } From c04ac2dcf88418a68cde20eb62388f21d243e45d Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 8 Jan 2026 16:46:27 -0800 Subject: [PATCH 63/79] Path Bugs and Logging Bugs in training --- experiments/exp2166_scaling_ladder_analysis.py | 6 +++--- lib/levanter/src/levanter/tracker/wandb.py | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 50becd8826..135d31f6ef 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -25,7 +25,7 @@ from experiments.defaults import default_validation_sets from experiments.isoflop_sweep import MARIN_2025_RECIPE, MARIN_SCALING_SUITES, nemotron_mix -from marin.execution.executor import ExecutorStep, executor_main, output_path_of +from marin.execution.executor import ExecutorStep, executor_main, output_path_of, this_output_path from marin.processing.tokenize import add_validation_sets_to_mixture from marin.scaling_laws import ( IsoFlopAnalysisConfig, @@ -53,7 +53,7 @@ fn=run_isoflop_analysis_step, config=IsoFlopAnalysisConfig( training_runs=[output_path_of(r) for r in nemotron_training], - output_path=f"analysis/{EXPERIMENT_NAME}", + output_path=this_output_path(), recipe=MARIN_2025_RECIPE, ), ) @@ -70,7 +70,7 @@ target_budget=budget, label=LABEL, tokenized=nemotron_mix_with_validation, - output_path=f"checkpoints/{EXPERIMENT_NAME}-optimal-{budget:.0e}", + output_path=this_output_path(), recipe=MARIN_2025_RECIPE, ), ) diff --git a/lib/levanter/src/levanter/tracker/wandb.py b/lib/levanter/src/levanter/tracker/wandb.py index 331d36e6f9..81333b0ca2 100644 --- a/lib/levanter/src/levanter/tracker/wandb.py +++ b/lib/levanter/src/levanter/tracker/wandb.py @@ -134,6 +134,13 @@ def _convert_value_to_loggable_rec(value: Any): return value.item() else: return np.array(value) + elif isinstance(value, np.ndarray): + if value.ndim == 0: + return value.item() + else: + return value.tolist() + elif isinstance(value, np.generic): + return value.item() elif isinstance(value, Histogram): import wandb From f946f83e809065b73bbb49b3231b81c143f7d543 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Thu, 8 Jan 2026 17:27:29 -0800 Subject: [PATCH 64/79] Serialization Issues --- lib/levanter/src/levanter/eval.py | 7 ++++- .../marin/scaling_laws/isoflop_analysis.py | 31 +++++-------------- 2 files changed, 14 
insertions(+), 24 deletions(-) diff --git a/lib/levanter/src/levanter/eval.py b/lib/levanter/src/levanter/eval.py index 7407c1f631..e8191d3abc 100644 --- a/lib/levanter/src/levanter/eval.py +++ b/lib/levanter/src/levanter/eval.py @@ -233,7 +233,12 @@ def eval_callback(step: StepInfo): fs, _, _ = fsspec.get_fs_token_paths(metrics_file) fs.makedirs(checkpoint_path, exist_ok=True) with fs.open(metrics_file, "a") as f: - record = {"step": int(step_count), **metrics_to_write} + # Convert numpy/jax floats to Python floats for JSON serialization + serializable_metrics = { + k: float(v) if isinstance(v, (np.floating, jnp.floating)) else v + for k, v in metrics_to_write.items() + } + record = {"step": int(step_count), **serializable_metrics} f.write(json.dumps(record, sort_keys=True) + "\n") return diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 4b8e4a3775..90835eda0a 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -202,33 +202,18 @@ def round_to_power_of_two(x: float) -> int: return 2 ** math.ceil(math.log2(x)) -def round_flops_to_bucket(flops: float) -> float: - """Round FLOP count to 1 significant figure (XeYY format). - - This ensures runs with slightly different achieved FLOPs are grouped - together for analysis when they were targeting the same budget. - Using 1 significant figure creates buckets at 1e19, 2e19, 3e19, etc., - which matches the typical spacing of isoflop budget targets. - - Note: This means 1.5e19 and 2.4e19 both map to 2e19. For finer granularity, - consider using 2 significant figures (round to nearest 0.1 mantissa). - - Examples: - 1.05e19 → 1e19 - 1.5e19 → 2e19 - 2.8e19 → 3e19 - 9.5e19 → 1e20 +def round_flops_to_bucket(flops: float, base: float = 1.1) -> float: + """Round FLOP count to the nearest power of base. + + Args: + flops: FLOP count to round. + base: Base for the power buckets (default 1.1 for ~10% buckets). 
""" if flops <= 0: return flops - exponent = math.floor(math.log10(flops)) - mantissa = flops / (10**exponent) - rounded_mantissa = round(mantissa) - - if rounded_mantissa == 10: - return 1.0 * (10 ** (exponent + 1)) - return float(rounded_mantissa) * (10**exponent) + k = math.log(flops) / math.log(base) + return base ** round(k) def compute_training_flops( From 114e8282b734d5fb62a3d0bae2b105ec026b659f Mon Sep 17 00:00:00 2001 From: Helw150 Date: Fri, 9 Jan 2026 13:50:15 -0800 Subject: [PATCH 65/79] More Grugging --- .../exp2166_scaling_ladder_analysis.py | 12 +- experiments/isoflop_sweep.py | 247 +++++++++ lib/marin/src/marin/scaling_laws/__init__.py | 20 +- .../marin/scaling_laws/eval_metrics_reader.py | 25 +- .../marin/scaling_laws/isoflop_analysis.py | 507 +++++------------- lib/marin/src/marin/scaling_laws/recipe.py | 81 --- .../src/marin/scaling_laws/scaling_ladder.py | 2 +- lib/marin/src/marin/scaling_laws/tpu_utils.py | 3 +- tests/test_scaling_laws.py | 40 +- 9 files changed, 414 insertions(+), 523 deletions(-) delete mode 100644 lib/marin/src/marin/scaling_laws/recipe.py diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 135d31f6ef..ce224cdbf8 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -24,13 +24,17 @@ """ from experiments.defaults import default_validation_sets -from experiments.isoflop_sweep import MARIN_2025_RECIPE, MARIN_SCALING_SUITES, nemotron_mix +from experiments.isoflop_sweep import ( + IsoFlopAnalysisConfig, + MARIN_2025_RECIPE, + MARIN_SCALING_SUITES, + nemotron_mix, + run_isoflop_analysis_step, +) from marin.execution.executor import ExecutorStep, executor_main, output_path_of, this_output_path from marin.processing.tokenize import add_validation_sets_to_mixture from marin.scaling_laws import ( - IsoFlopAnalysisConfig, ScalingLadderRungConfig, - run_isoflop_analysis_step, run_scaling_ladder_rung, ) @@ -52,7 +56,7 @@ name=f"{EXPERIMENT_NAME}-analysis", fn=run_isoflop_analysis_step, config=IsoFlopAnalysisConfig( - training_runs=[output_path_of(r) for r in nemotron_training], + training_runs=tuple(output_path_of(r) for r in nemotron_training), output_path=this_output_path(), recipe=MARIN_2025_RECIPE, ), diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index a0eba80478..283a942342 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -18,7 +18,10 @@ sizes while keeping the total training FLOPs roughly constant. """ +import logging import math +import os +import re from collections.abc import Iterator from dataclasses import dataclass, replace @@ -42,12 +45,20 @@ from marin.processing.tokenize import get_vocab_size_for_tokenizer, lm_mixture_data_config from marin.scaling_laws import ( CandidateConfig, + FitScalingLawsResult, + IsoFlopRecord, ScalingRecipe, + fit_scaling_laws, generate_isoflop_train_args, pick_v5p_type, + round_flops_to_bucket, solve_for_batch_size, solve_for_train_steps, ) +from marin.scaling_laws.eval_metrics_reader import read_raw_records +from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT + +logger = logging.getLogger(__name__) DEFAULT_BUDGETS: tuple[float, ...] = (1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20) LEGACY_BUDGETS: tuple[float, ...] 
= (3e18, 9e18, 1.8e19, 3e19, 9e19, 1.8e20, 3e20)
@@ -55,6 +66,125 @@
 DEFAULT_STEPS_PER_RUN: int = 2**16
 DEFAULT_FLOP_TOLERANCE: float = 0.01

+# ---------------- Levanter WandB Metric Keys ----------------
+# These keys correspond to the metrics logged by Levanter's training callbacks.
+THROUGHPUT_TOKENS_KEY = "throughput/total_tokens"
+THROUGHPUT_GFLOPS_KEY = "throughput/total_gflops"
+PARAMETER_COUNT_KEY = "parameter_count"
+MODEL_CONFIG_KEY = "model"
+TRAINER_CONFIG_KEY = "trainer"
+DEFAULT_METRIC_KEY = "eval/paloma/c4_en/bpb"
+
+
+# ---------------- Levanter Metrics Transform ----------------
+
+
+def parse_isoflop_run_name(run_name: str) -> str | None:
+    """Parse experiment name from isoflop run name.
+
+    Supports two formats:
+    - New: isoflop-{budget}-N{params}-B{batch}-{experiment_name}
+      E.g., 'isoflop-1e+18-N1e+08-B128-nemo-wider-depth-adapt'
+    - Legacy: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name}
+      E.g., 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt'
+
+    Either form may optionally carry a trailing 6-hex-character hash suffix, which is ignored.
+
+    Returns experiment_name or None if parsing fails.
+    """
+    # Strip the optional 6-hex-character hash suffix
+    run_name = re.sub(r"-[0-9a-fA-F]{6}$", "", run_name)
+
+    # New format: isoflop-{budget}-N{params}-B{batch}-{experiment_name}
+    new_pattern = r"isoflop-(?:[0-9.e+]+)-N(?:[0-9.e+]+)-B(?:\d+)-(.+)"
+    match = re.match(new_pattern, run_name)
+    if match:
+        return match.group(1)
+
+    # Legacy format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name}
+    legacy_pattern = r"isoflop-(?:[0-9.e+]+)-d(?:\d+)-L(?:\d+)-B(?:\d+)-(.+)"
+    match = re.match(legacy_pattern, run_name)
+    if match:
+        return match.group(1)
+
+    return None
+
+
+def transform_levanter_metrics(
+    raw_records: list[dict],
+    metric_key: str = DEFAULT_METRIC_KEY,
+    label_map: dict[str, str] | None = None,
+    min_flops: float = 1e18,
+) -> list[IsoFlopRecord]:
+    """Transform raw Levanter metrics into IsoFlopRecord list.
+
+    Args:
+        raw_records: Raw records from read_raw_records(), each containing
+            'config', 'summary', and 'run_path' keys.
+        metric_key: Which metric to use (default: eval/paloma/c4_en/bpb).
+        label_map: Optional mapping from experiment_name -> display label.
+        min_flops: Minimum FLOP threshold to include (default: 1e18).
+
+    Returns:
+        List of IsoFlopRecord for records that have all required fields.
+        Records missing required fields are logged and skipped.
+ """ + records = [] + + for raw in raw_records: + run_path = raw.get("run_path", "") + run_name = os.path.basename(run_path.rstrip("/")) + + summary = raw.get("summary", {}) or {} + + # Extract tokens + tokens = summary.get(THROUGHPUT_TOKENS_KEY) + if tokens is None: + logger.warning(f"Missing {THROUGHPUT_TOKENS_KEY} for run {run_name}, skipping") + continue + + # Extract FLOPs (convert GFLOPs to FLOPs and bucket) + total_gflops = summary.get(THROUGHPUT_GFLOPS_KEY) + if total_gflops is None: + logger.warning(f"Missing {THROUGHPUT_GFLOPS_KEY} for run {run_name}, skipping") + continue + flops = round_flops_to_bucket(total_gflops * 1e9) + + if flops < min_flops: + continue + + # Extract metric + metric = summary.get(metric_key) + if metric is None: + logger.warning(f"Missing metric {metric_key} for run {run_name}, skipping") + continue + + # Extract params (required) + params = summary.get(PARAMETER_COUNT_KEY) + if params is None: + logger.warning(f"Missing {PARAMETER_COUNT_KEY} for run {run_name}, skipping") + continue + + # Determine label from run name + exp_name = parse_isoflop_run_name(run_name) or run_name + if label_map and exp_name in label_map: + label = label_map[exp_name] + else: + label = exp_name + + records.append( + IsoFlopRecord( + tokens=float(tokens), + metric=float(metric), + flops=float(flops), + params=float(params), + label=label, + ) + ) + + logger.info(f"Transformed {len(records)} records from {len(raw_records)} raw records") + return records + def _round_to_power_of_two(x: float) -> int: """Round x UP to the nearest power of 2.""" @@ -304,6 +434,123 @@ def candidate_configs( """Default Marin scaling recipe.""" +# ---------------- IsoFlop Analysis ---------------- + + +@dataclass(frozen=True, kw_only=True) +class IsoFlopAnalysisConfig: + """Configuration for IsoFLOP scaling law analysis. + + The training_runs field creates blocking dependencies on the training jobs. + This config is for use with ExecutorStep. + """ + + training_runs: tuple[str, ...] + """Training run output paths (executor resolves InputName to str at runtime).""" + + output_path: str + """Where to write analysis outputs.""" + + recipe: ScalingRecipe + """Scaling recipe for computing optimal hyperparameters.""" + + metric_key: str = DEFAULT_METRIC_KEY + """Metric to use for loss (default: eval/paloma/c4_en/bpb).""" + + label_map: tuple[tuple[str, str], ...] | None = None + """Optional mapping from experiment_name -> display label as tuple of pairs.""" + + metrics_filename: str = "tracker_metrics.jsonl" + """Name of the metrics file within each checkpoint directory.""" + + backfill_from_wandb: bool = True + """If True, backfill tracker_metrics.jsonl from WandB for runs that completed before this feature.""" + + wandb_entity_project: str = f"{WANDB_ENTITY}/{WANDB_PROJECT}" + """WandB entity/project to query for backfill (format: 'entity/project').""" + + +def run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> FitScalingLawsResult: + """Execute IsoFLOP scaling law analysis. + + This is the experiment step function that: + 1. Reads raw metrics from training runs + 2. Transforms them using Levanter schema knowledge + 3. Runs the scaling law analysis + 4. 
Saves results to output_path + + Args: + config: Analysis config with training_runs and analysis settings + + Returns: + FitScalingLawsResult with fitted scaling laws + """ + import json + + import fsspec + + # Read raw records from training runs + raw_records = read_raw_records(config) + + if not raw_records: + logger.warning("No eval metrics found in training runs") + return FitScalingLawsResult(minima_records=[], scaling_fits={}, fit_curves={}) + + # Transform to typed records using Levanter schema knowledge + label_map = dict(config.label_map) if config.label_map else None + records = transform_levanter_metrics(raw_records, config.metric_key, label_map) + + if not records: + logger.warning("No valid isoflop data after transformation") + return FitScalingLawsResult(minima_records=[], scaling_fits={}, fit_curves={}) + + logger.info(f"Loaded {len(records)} runs for scaling law analysis") + labels = list(dict.fromkeys(r.label for r in records)) + flops_budgets = sorted(set(r.flops for r in records)) + logger.info(f"Labels found: {labels}") + logger.info(f"FLOP budgets: {flops_budgets}") + + # Run scaling law analysis + result = fit_scaling_laws(records) + + logger.info(f"Found {len(result.minima_records)} optimal configurations") + for label, scaling_fit in result.scaling_fits.items(): + logger.info(f" {label}: D* = {scaling_fit.A:.2e} * C^{scaling_fit.alpha:.3f}") + + # Save results + fs, _, _ = fsspec.get_fs_token_paths(config.output_path) + fs.makedirs(config.output_path, exist_ok=True) + + result_path = os.path.join(config.output_path, "isoflop_analysis_result.json") + result_dict = { + "minima_records": [ + { + "label": r.label, + "flops": r.flops, + "optimal_tokens": r.optimal_tokens, + "loss_at_optimal": r.loss_at_optimal, + "optimal_params": r.optimal_params, + "scaling_alpha": r.scaling_alpha, + "scaling_A": r.scaling_A, + } + for r in result.minima_records + ], + "scaling_fits": {k: list(v) for k, v in result.scaling_fits.items()}, + } + with fs.open(result_path, "w") as f: + json.dump(result_dict, f, indent=2) + logger.info(f"Saved results to {result_path}") + + # Save fit curves for downstream plotting + fit_curves_path = os.path.join(config.output_path, "fit_curves.json") + fit_curves_json = {f"{label}|{flops}": list(coeffs) for (label, flops), coeffs in result.fit_curves.items()} + with fs.open(fit_curves_path, "w") as f: + json.dump(fit_curves_json, f, indent=2) + logger.info(f"Saved fit curves to {fit_curves_path}") + + return result + + def create_isoflop_sweep_steps( tokenized: InputName | str | LMMixtureDatasetConfig, experiment_name: str, diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index ec3735a551..e76f1c0734 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -20,29 +20,26 @@ DEFAULT_STEPS_PER_RUN, CandidateConfig, FitScalingLawsResult, - IsoFlopAnalysisConfig, - IsoFlopAnalysisResult, + IsoFlopRecord, IsoFlopTrainArgs, MinimaRecord, + ModelConfiguration, QuadraticFitCoeffs, ScalingFit, + ScalingRecipe, candidate_configs, compute_training_flops, fit_scaling_laws, generate_isoflop_train_args, predict_optimal_config, predict_optimal_configs_for_budgets, - run_isoflop_analysis, - run_isoflop_analysis_step, + round_flops_to_bucket, solve_for_batch_size, solve_for_train_steps, ) from marin.scaling_laws.tpu_utils import ( pick_v5p_type, ) -from marin.scaling_laws.recipe import ( - ScalingRecipe, -) from marin.scaling_laws.scaling_ladder import ( 
ScalingLadderRungConfig, run_scaling_ladder_rung, @@ -61,13 +58,13 @@ "DEFAULT_FLOP_TOLERANCE", "DEFAULT_SEQ_LEN", "DEFAULT_STEPS_PER_RUN", - # Data classes + # Data classes and Protocols "CandidateConfig", "FitScalingLawsResult", - "IsoFlopAnalysisConfig", - "IsoFlopAnalysisResult", + "IsoFlopRecord", "IsoFlopTrainArgs", "MinimaRecord", + "ModelConfiguration", "QuadraticFitCoeffs", "ScalingFit", "ScalingLadderRungConfig", @@ -82,8 +79,7 @@ "pick_v5p_type", "predict_optimal_config", "predict_optimal_configs_for_budgets", - "run_isoflop_analysis", - "run_isoflop_analysis_step", + "round_flops_to_bucket", "run_scaling_ladder_rung", "save_plots", "solve_for_batch_size", diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index 15f0354d88..1f3132f55f 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -14,21 +14,17 @@ """Base infrastructure for eval metrics analysis. -This module provides a base config and utilities for analysis jobs that -read tracker_metrics.jsonl files from completed training runs. The subclassing -pattern mirrors the Evaluator approach in -lib/marin/src/marin/evaluation/evaluators/evaluator.py, so specific analyses -(like IsoFlop) should subclass EvalMetricsAnalysisConfig. +This module provides a config and utilities for analysis jobs that +read tracker_metrics.jsonl files from completed training runs. """ import json import logging import os -from dataclasses import dataclass from collections.abc import Sequence +from dataclasses import dataclass import fsspec -import pandas as pd import wandb from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT @@ -92,9 +88,8 @@ def _backfill_metrics_from_wandb( @dataclass(frozen=True) class EvalMetricsAnalysisConfig: - """Base config for analyses that read eval metrics from training runs. + """Config for analyses that read eval metrics from training runs. - Subclass this to create specific analysis types (e.g., IsoFlopAnalysisConfig). The training_runs field creates blocking dependencies on the training jobs. """ @@ -114,9 +109,8 @@ class EvalMetricsAnalysisConfig: """WandB entity/project to query for backfill (format: 'entity/project').""" -def read_metrics_dataframe(config: EvalMetricsAnalysisConfig) -> pd.DataFrame: - """ - Read eval metrics from training runs into a DataFrame. +def read_raw_records(config: EvalMetricsAnalysisConfig) -> list[dict]: + """Read raw eval metrics from training runs. This is the shared utility that all analysis subtypes use to load metrics. It handles reading JSONL files and optional WandB backfill. @@ -125,7 +119,7 @@ def read_metrics_dataframe(config: EvalMetricsAnalysisConfig) -> pd.DataFrame: config: Analysis config with training_runs and backfill settings Returns: - DataFrame with columns: step, run_index, run_path, + all eval/* metrics + List of raw records, each containing config, summary, run_index, and run_path. 
""" all_records = [] @@ -167,9 +161,6 @@ def read_metrics_dataframe(config: EvalMetricsAnalysisConfig) -> pd.DataFrame: if not all_records: logger.warning("No eval metrics found in any training runs") - return pd.DataFrame() - df = pd.DataFrame(all_records) logger.info(f"Loaded {len(all_records)} evaluation records from {len(config.training_runs)} runs") - logger.info(f"Available columns: {list(df.columns)}") - return df + return all_records diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 90835eda0a..68feb33ce0 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -14,36 +14,31 @@ """IsoFLOP analysis for finding compute-optimal training configurations. -This module provides functions and configs for IsoFLOP scaling law analysis: -1. Read eval metrics from completed training runs -2. Fit scaling laws to find compute-optimal token counts -3. Save results to JSON/parquet files - -For programmatic use (without ExecutorStep), see `run_isoflop_analysis()`. +This module provides the core data types and analysis functions for IsoFLOP +scaling law analysis. It is intentionally schema-agnostic - experiment code +should transform raw metrics into IsoFlopRecord before calling these functions. + +Key types: +- IsoFlopRecord: The contract for a single training run's metrics +- FitScalingLawsResult: Output from fit_scaling_laws() +- CandidateConfig: Model-agnostic compute allocation from scaling law analysis + +Key functions: +- fit_scaling_laws(records): Fit scaling laws from typed records +- predict_optimal_config(): Predict optimal training config for a target budget +- generate_isoflop_train_args(): Generate training args for an isoflop sweep """ -import json import logging import math -import os -import re from collections.abc import Iterator, Sequence -from dataclasses import asdict, dataclass, field -from typing import NamedTuple +from dataclasses import dataclass +from typing import NamedTuple, Protocol -import fsspec import jax.numpy as jnp -import pandas as pd from jaxopt import ScipyMinimize -from levanter.models.llama import LlamaConfig - -from marin.scaling_laws.eval_metrics_reader import ( - EvalMetricsAnalysisConfig, - extract_run_name_from_path, - read_metrics_dataframe, -) -from marin.scaling_laws.recipe import ScalingRecipe +from levanter.optim.config import OptimizerConfig logger = logging.getLogger(__name__) @@ -63,17 +58,6 @@ # This matches how FLOPs are tracked in WandB via Levanter's log_performance_stats. DEFAULT_BUDGETS: tuple[float, ...] = (1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20) -# Derived from Kaiyue's hyperparameter sweep: optimal_LR * hidden_size * sqrt(batch_size) -LR_CONSTANT = 0.33 - -# ---------------- WandB Metric Keys ---------------- -# These keys correspond to the metrics logged by Levanter's training callbacks. -THROUGHPUT_TOKENS_KEY = "throughput/total_tokens" -THROUGHPUT_GFLOPS_KEY = "throughput/total_gflops" -PARAMETER_COUNT_KEY = "parameter_count" -MODEL_CONFIG_KEY = "model" -TRAINER_CONFIG_KEY = "trainer" - # ---------------- Typed Tuples ---------------- @@ -107,6 +91,33 @@ class QuadraticFitCoeffs(NamedTuple): """Maximum token count used for fitting.""" +# ---------------- IsoFlopRecord ---------------- + + +@dataclass +class IsoFlopRecord: + """A single training run record for isoflop analysis. 
+ + This is the contract between experiment code (which knows how to extract + these fields from raw metrics) and the analysis code (which just does math). + """ + + tokens: float + """Total tokens trained on.""" + + metric: float + """Evaluation metric value (e.g., bits-per-byte from Paloma).""" + + flops: float + """Total training FLOPs (bucketed).""" + + params: float + """Parameter count.""" + + label: str + """Experiment label for grouping (e.g., 'nemo', 'dclm').""" + + # ---------------- IsoFLOP Sweep Defaults ---------------- DEFAULT_SEQ_LEN = SEQ_LEN DEFAULT_STEPS_PER_RUN = 2**16 # Reference step count for hyperparameter tuning @@ -136,6 +147,57 @@ class CandidateConfig: flops_budget: float # Compute budget this config was generated for +class ModelConfiguration(Protocol): + """Protocol for model configs used in scaling law calculations. + + Any model config that implements flops_per_token can be used with the + scaling law functions. This allows the library to be model-agnostic + while still working with LlamaConfig, QwenConfig, etc. + """ + + def flops_per_token(self, vocab_size: int, seq_len: int) -> float: + """Return FLOPs per token for this model configuration.""" + ... + + +class ScalingRecipe(Protocol): + """Protocol defining the interface for scaling law recipes. + + Concrete implementations (e.g., Marin2025Recipe) should implement these + model-specific methods. Orchestration logic (generating training args, + predicting optimal configs) is handled by library functions that use + these core methods. + """ + + name: str + """Name identifying this recipe (e.g., 'marin-2025').""" + + def build_model_config( + self, target_params: int, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN + ) -> ModelConfiguration: + """Build a model config for a target parameter count.""" + ... + + def estimate_memory_bytes(self, model_config: ModelConfiguration, batch_size: int, vocab_size: int) -> int: + """Estimate memory usage in bytes for training with this model config.""" + ... + + def build_optimizer_config(self, candidate: CandidateConfig, vocab_size: int) -> OptimizerConfig: + """Build optimizer config for a candidate.""" + ... + + def candidate_configs( + self, + budget: float, + vocab_size: int, + seq_len: int = DEFAULT_SEQ_LEN, + steps_per_run: int = DEFAULT_STEPS_PER_RUN, + flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, + ) -> Iterator[CandidateConfig]: + """Yield candidate configurations within the FLOP budget.""" + ... + + @dataclass class IsoFlopTrainArgs: """Arguments needed to set up an isoflop training run. @@ -173,7 +235,6 @@ class MinimaRecord: optimal_tokens: float loss_at_optimal: float optimal_params: float - batch_size: int scaling_alpha: float | None = None scaling_A: float | None = None @@ -217,7 +278,7 @@ def round_flops_to_bucket(flops: float, base: float = 1.1) -> float: def compute_training_flops( - model_config: "LlamaConfig", + model_config: ModelConfiguration, vocab_size: int, batch_size: int, train_steps: int, @@ -230,7 +291,7 @@ def compute_training_flops( (see train_lm.py) and standard ML conventions (e.g., Chinchilla paper). Args: - model_config: Levanter model config with flops_per_token method (LlamaConfig or subclass). + model_config: Model config with flops_per_token method. vocab_size: Vocabulary size. batch_size: Training batch size. train_steps: Number of training steps. 
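The hunks above only swap type annotations, but the identity these helpers share is worth stating once: total training FLOPs = 3 * flops_per_token(vocab_size, seq_len) * batch_size * train_steps * seq_len, with solve_for_batch_size and solve_for_train_steps each inverting it for one unknown, exactly as their docstrings say. A minimal sketch under that assumption (the standalone helper names here are illustrative, not part of the patch):

```python
# Illustrative sketch of the shared FLOP identity; `fpt` stands in for
# flops_per_token(vocab_size, seq_len) from any ModelConfiguration-compatible config.
def total_training_flops(fpt: float, batch_size: int, train_steps: int, seq_len: int) -> float:
    return 3 * fpt * batch_size * train_steps * seq_len  # forward + backward ~= 3x forward

def batch_for_budget(fpt: float, target_flops: float, train_steps: int, seq_len: int) -> float:
    return target_flops / (3 * fpt * train_steps * seq_len)  # what solve_for_batch_size computes

def steps_for_budget(fpt: float, target_flops: float, batch_size: int, seq_len: int) -> float:
    return target_flops / (3 * fpt * batch_size * seq_len)  # what solve_for_train_steps computes
```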
@@ -245,7 +306,7 @@ def compute_training_flops( def solve_for_batch_size( - model_config: "LlamaConfig", + model_config: ModelConfiguration, vocab_size: int, target_flops: float, train_steps: int, @@ -257,7 +318,7 @@ def solve_for_batch_size( Solve: batch = total_flops / (3 * flops_per_token * steps * seq_len) Args: - model_config: Levanter model config with flops_per_token method. + model_config: Model config with flops_per_token method. vocab_size: Vocabulary size. target_flops: Target total training FLOPs. train_steps: Number of training steps. @@ -271,7 +332,7 @@ def solve_for_batch_size( def solve_for_train_steps( - model_config: "LlamaConfig", + model_config: ModelConfiguration, vocab_size: int, target_flops: float, batch_size: int, @@ -283,7 +344,7 @@ def solve_for_train_steps( Solve: steps = total_flops / (3 * flops_per_token * batch * seq_len) Args: - model_config: Levanter model config with flops_per_token method. + model_config: Model config with flops_per_token method. vocab_size: Vocabulary size. target_flops: Target total training FLOPs. batch_size: Training batch size. @@ -324,33 +385,6 @@ def candidate_configs( yield from recipe.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance) -def _minima_to_candidates( - minima_records: list[MinimaRecord], -) -> list[CandidateConfig]: - """Convert minima records to model-agnostic CandidateConfig objects. - - This is used by both run_isoflop_analysis_step() and run_isoflop_analysis() - to convert the fitted minima into usable candidate configs. - - Args: - minima_records: List of optimal configurations from scaling law fits. - """ - configs = [] - for rec in minima_records: - if rec.optimal_params == 0: - continue - configs.append( - CandidateConfig( - batch_size=rec.batch_size, - train_steps=int(rec.optimal_tokens / (rec.batch_size * SEQ_LEN)), - tokens=rec.optimal_tokens, - target_params=int(rec.optimal_params), - flops_budget=rec.flops, - ) - ) - return configs - - # ---------------- Training Args Generation ---------------- @@ -415,40 +449,6 @@ def generate_isoflop_train_args( return results -# ---------------- Helpers ---------------- - - -def parse_isoflop_run_name(run_name: str) -> str | None: - """Parse experiment name from isoflop run name. - - Supports two formats: - - New: isoflop-{budget}-N{params}-B{batch}-{experiment_name} - E.g., 'isoflop-1e+18-N1e+08-B128-nemo-wider-depth-adapt' - - Legacy: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} - E.g., 'isoflop-1e+19-d2048-L16-B1024-nemo-wider-depth-adapt' - - Optionally with a trailing - which is ignored. - - Returns experiment_name or None if parsing fails. - """ - # Strip optional - suffix - run_name = re.sub(r"-[0-9a-fA-F]{6}$", "", run_name) - - # New format: isoflop-{budget}-N{params}-B{batch}-{experiment_name} - new_pattern = r"isoflop-(?:[0-9.e+]+)-N(?:[0-9.e+]+)-B(?:\d+)-(.+)" - match = re.match(new_pattern, run_name) - if match: - return match.group(1) - - # Legacy format: isoflop-{budget}-d{hidden}-L{layers}-B{batch}-{experiment_name} - legacy_pattern = r"isoflop-(?:[0-9.e+]+)-d(?:\d+)-L(?:\d+)-B(?:\d+)-(.+)" - match = re.match(legacy_pattern, run_name) - if match: - return match.group(1) - - return None - - def robust_quad_logx(x: jnp.ndarray, y: jnp.ndarray, delta: float = 1.0) -> tuple[float, float, float]: """Fit a robust quadratic in log10(x) space using Huber loss. 
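The next hunk switches fit_scaling_laws() over to the IsoFlopRecord contract introduced above. For orientation, a hedged usage sketch: the field values below are invented, in practice transform_levanter_metrics() builds the records, and each (label, FLOP bucket) needs several token counts for the quadratic fit to have a usable minimum.

```python
# Minimal sketch, assuming the exports added to marin.scaling_laws in this series.
from marin.scaling_laws import IsoFlopRecord, fit_scaling_laws

records = [
    IsoFlopRecord(tokens=1.3e9, metric=1.16, flops=1e18, params=1.3e8, label="nemo"),
    IsoFlopRecord(tokens=2.5e9, metric=1.12, flops=1e18, params=6.7e7, label="nemo"),
    IsoFlopRecord(tokens=5.0e9, metric=1.15, flops=1e18, params=3.3e7, label="nemo"),
    # ... more records, ideally several FLOP budgets per label ...
]

result = fit_scaling_laws(records)
for rec in result.minima_records:
    print(rec.label, rec.flops, rec.optimal_tokens, rec.loss_at_optimal, rec.optimal_params)
for label, fit in result.scaling_fits.items():
    print(f"{label}: D* = {fit.A:.2e} * C^{fit.alpha:.3f}")
```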
@@ -491,23 +491,24 @@ def objective(params): def fit_scaling_laws( - df: pd.DataFrame, + records: list[IsoFlopRecord], ) -> FitScalingLawsResult: - """ - Fit scaling laws and extract optimal configurations. + """Fit scaling laws and extract optimal configurations. Args: - df: DataFrame with columns: tokens, loss, flops, params, name, label + records: List of IsoFlopRecord with tokens, metric, flops, params, label, batch_size. Returns: FitScalingLawsResult containing minima_records, scaling_fits, and fit_curves. """ - if df is None or df.empty: + if not records: return FitScalingLawsResult(minima_records=[], scaling_fits={}, fit_curves={}) - datasets = list(dict.fromkeys(df["label"].tolist())) + # Get unique labels preserving order of first appearance + datasets = list(dict.fromkeys(r.label for r in records)) - buckets = sorted(df.flops.unique()) + # Get unique flop buckets + buckets = sorted(set(r.flops for r in records)) minima_records: list[MinimaRecord] = [] fit_curves: dict[tuple[str, float], QuadraticFitCoeffs] = {} @@ -515,16 +516,19 @@ def fit_scaling_laws( # Fit quadratic for each (label, budget) and find minima for lab in datasets: for C in buckets: - sub = df[(df.flops == C) & (df.label == lab)].sort_values("tokens") - if sub.empty: + sub = sorted( + [r for r in records if r.flops == C and r.label == lab], + key=lambda r: r.tokens, + ) + if not sub: continue # Robust quadratic fit in log10(tokens) # Use float64 to avoid int32 overflow for token counts > 2^31 - tokens_array = jnp.array(sub.tokens.values, dtype=jnp.float64) + tokens_array = jnp.array([r.tokens for r in sub], dtype=jnp.float64) a, b, c = robust_quad_logx( tokens_array, - jnp.array(sub.loss.values, dtype=jnp.float64), + jnp.array([r.metric for r in sub], dtype=jnp.float64), ) # Store coefficients along with token range used for fitting fit_curves[(lab, C)] = QuadraticFitCoeffs(a, b, c, float(tokens_array.min()), float(tokens_array.max())) @@ -534,28 +538,18 @@ def fit_scaling_laws( log_D_opt = -b / (2 * a) D_star = float(10**log_D_opt) - loss_opt = float(a * log_D_opt**2 + b * log_D_opt + c) - - idx = (sub.tokens - D_star).abs().argmin() - nearest_row = sub.iloc[idx] + metric_opt = float(a * log_D_opt**2 + b * log_D_opt + c) - # Require params to be present - the 6ND approximation is inaccurate for small models - params = nearest_row.get("params") - if params is None or pd.isna(params): - logger.warning( - f"Missing params for {lab} at {C:.1e} FLOPs - skipping. " - "Ensure runs log parameter_count or have full model config." - ) - continue + # Find record with tokens closest to optimal + nearest_record = min(sub, key=lambda r: abs(r.tokens - D_star)) minima_records.append( MinimaRecord( label=lab, flops=float(C), optimal_tokens=D_star, - loss_at_optimal=loss_opt, - optimal_params=float(params), - batch_size=int(nearest_row["batch_size"]), + loss_at_optimal=metric_opt, + optimal_params=nearest_record.params, ) ) @@ -591,97 +585,6 @@ def fit_scaling_laws( ) -def transform_metrics_for_isoflop( - df: pd.DataFrame, - metric_key: str, - label_map: dict[str, str] | None = None, -) -> pd.DataFrame: - """Transform raw metrics DataFrame into isoflop analysis format. - - Takes the generic metrics DataFrame from read_metrics_dataframe() and - transforms it into the format expected by the analysis: - columns: tokens, loss, flops, params, name, label - - The DataFrame contains nested 'config' and 'summary' dicts from tracker_metrics.jsonl. 
- - Args: - df: Raw metrics DataFrame from read_metrics_dataframe() - metric_key: Which metric column to use for loss (e.g., 'eval/paloma/c4_en/bpb') - label_map: Optional mapping from experiment_name -> display label - - Returns: - Transformed DataFrame ready for fit_scaling_laws() - """ - if df.empty: - return pd.DataFrame(columns=["tokens", "loss", "flops", "params", "name", "label"]) - - records = [] - for _, row in df.iterrows(): - run_path = row["run_path"] - run_name = extract_run_name_from_path(run_path) - - # Extract config and summary dicts - config = row.get("config", {}) or {} - summary = row.get("summary", {}) or {} - model_config = config.get(MODEL_CONFIG_KEY, {}) or {} - trainer_config = config.get(TRAINER_CONFIG_KEY, {}) or {} - - # Get tokens directly from summary - tokens = summary.get(THROUGHPUT_TOKENS_KEY) - if tokens is None or pd.isna(tokens): - logger.warning(f"Missing {THROUGHPUT_TOKENS_KEY} in summary for run {run_name}") - continue - - # Get total FLOPs from summary (convert GFLOPs to FLOPs) - total_gflops = summary.get(THROUGHPUT_GFLOPS_KEY) - if total_gflops is None or pd.isna(total_gflops): - logger.warning(f"Missing {THROUGHPUT_GFLOPS_KEY} in summary for run {run_name}") - continue - flops = round_flops_to_bucket(total_gflops * 1e9) - - if flops < 1e18: - continue - - # Get loss from summary[metric_key] - loss = summary.get(metric_key) - if loss is None or pd.isna(loss): - logger.warning(f"Missing metric {metric_key} for run {run_name}") - continue - - # Get parameter count from summary (required for accurate scaling analysis) - params = summary.get(PARAMETER_COUNT_KEY) - if params is None or pd.isna(params): - params = None - - # Get model architecture from config - hidden_dim = model_config.get("hidden_dim") - num_layers = model_config.get("num_layers") - batch_size = trainer_config.get("train_batch_size") - - # Determine experiment name and label from run name - exp_name = parse_isoflop_run_name(run_name) or run_name - if label_map and exp_name in label_map: - label = label_map[exp_name] - else: - label = exp_name - - records.append( - dict( - tokens=tokens, - loss=loss, - flops=flops, - params=params, - hidden_dim=hidden_dim, - num_layers=num_layers, - batch_size=batch_size, - name=run_name, - label=label, - ) - ) - - return pd.DataFrame.from_records(records) - - # ---------------- Predict Optimal Config ---------------- @@ -787,173 +690,3 @@ def predict_optimal_configs_for_budgets( ) configs.append(config) return configs - - -# ---------------- Result Dataclass ---------------- - - -@dataclass -class IsoFlopAnalysisResult: - """Result from scaling ladder analysis containing optimal configs and analysis data.""" - - configs: list[CandidateConfig] - """List of optimal CandidateConfig for each (label, flops_budget) pair.""" - - scaling_fits: dict[str, ScalingFit] - """Per-label scaling fits: {label: ScalingFit} for N* ~ A * C^alpha.""" - - isoflop_df: pd.DataFrame - """Transformed dataframe used for analysis.""" - - minima_records: list[MinimaRecord] - """Raw minima records with detailed info for each optimum.""" - - fit_curves: dict[tuple[str, float], QuadraticFitCoeffs] - """Quadratic fit coefficients {(label, flops): QuadraticFitCoeffs} for plotting.""" - - def to_json_dict(self) -> dict: - """Convert result to JSON-serializable dict (excludes DataFrame and fit_curves).""" - return { - "configs": [asdict(c) for c in self.configs], - "scaling_fits": {k: list(v) for k, v in self.scaling_fits.items()}, - "minima_records": [asdict(r) for r in 
self.minima_records], - } - - -# ---------------- ExecutorStep Config ---------------- - - -@dataclass(frozen=True) -class IsoFlopAnalysisConfig(EvalMetricsAnalysisConfig): - """Configuration for scaling ladder analysis ExecutorStep.""" - - recipe: ScalingRecipe = field(kw_only=True) - """Scaling recipe for computing optimal hyperparameters.""" - - metric_key: str = field(default=DEFAULT_EVAL_METRIC_KEY, kw_only=True) - """Metric to use for loss (default: eval/paloma/c4_en/bpb - Paloma benchmark on C4 English).""" - - label_map: tuple[tuple[str, str], ...] | None = field(default=None, kw_only=True) - """Optional mapping from experiment_name -> display label as tuple of pairs.""" - - -def run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> None: - """Execute scaling ladder analysis (called by ExecutorStep).""" - raw_df = read_metrics_dataframe(config) - - if raw_df.empty: - logger.warning("No eval metrics found in training runs") - return - - label_map = dict(config.label_map) if config.label_map else None - isoflop_df = transform_metrics_for_isoflop(raw_df, config.metric_key, label_map) - - if isoflop_df.empty: - logger.warning("No valid isoflop data after transformation") - return - - logger.info(f"Loaded {len(isoflop_df)} runs for scaling ladder analysis") - logger.info(f"Labels found: {isoflop_df['label'].unique().tolist()}") - logger.info(f"FLOP budgets: {sorted(isoflop_df['flops'].unique())}") - - fit_result = fit_scaling_laws(isoflop_df) - - logger.info(f"Found {len(fit_result.minima_records)} optimal configurations") - for label, (alpha, A) in fit_result.scaling_fits.items(): - logger.info(f" {label}: N* = {A:.2e} * C^{alpha:.3f}") - - configs = _minima_to_candidates(fit_result.minima_records) - - result = IsoFlopAnalysisResult( - configs=configs, - scaling_fits=fit_result.scaling_fits, - isoflop_df=isoflop_df, - minima_records=fit_result.minima_records, - fit_curves=fit_result.fit_curves, - ) - - fs, _, _ = fsspec.get_fs_token_paths(config.output_path) - fs.makedirs(config.output_path, exist_ok=True) - - result_path = os.path.join(config.output_path, "isoflop_analysis_result.json") - with fs.open(result_path, "w") as f: - json.dump(result.to_json_dict(), f, indent=2) - logger.info(f"Saved results to {result_path}") - - # Also save the full dataframe and fit curves for downstream plotting - df_path = os.path.join(config.output_path, "isoflop_df.parquet") - isoflop_df.to_parquet(df_path) - logger.info(f"Saved dataframe to {df_path}") - - fit_curves_path = os.path.join(config.output_path, "fit_curves.json") - # Convert tuple keys to strings for JSON serialization - fit_curves_json = {f"{label}|{flops}": list(coeffs) for (label, flops), coeffs in result.fit_curves.items()} - with fs.open(fit_curves_path, "w") as f: - json.dump(fit_curves_json, f, indent=2) - logger.info(f"Saved fit curves to {fit_curves_path}") - - -# ---------------- Programmatic Interface ---------------- - - -def run_isoflop_analysis( - training_runs: Sequence[str], - recipe: ScalingRecipe, - metric_key: str = DEFAULT_EVAL_METRIC_KEY, - label_map: dict[str, str] | None = None, -) -> IsoFlopAnalysisResult: - """Analyze isoflop training runs and return optimal training configurations. - - This is the programmatic interface for scaling ladder analysis, useful for - notebooks or scripts. For ExecutorStep-based pipelines, use - `run_isoflop_analysis_step()` with `IsoFlopAnalysisConfig`. 
- - Args: - training_runs: List of path strings to training run output directories - metric_key: Which metric to use for loss (default: eval/paloma/c4_en/bpb) - label_map: Optional mapping from experiment_name -> display label - recipe: ScalingRecipe with hyperparameter settings - - Returns: - IsoFlopAnalysisResult with configs, scaling_fits, and analysis data - """ - config = EvalMetricsAnalysisConfig( - training_runs=training_runs, - output_path="analysis/scaling_ladder", - ) - raw_df = read_metrics_dataframe(config) - - if raw_df.empty: - logger.warning("No eval metrics found") - return IsoFlopAnalysisResult( - configs=[], - scaling_fits={}, - isoflop_df=pd.DataFrame(), - minima_records=[], - fit_curves={}, - ) - - isoflop_df = transform_metrics_for_isoflop(raw_df, metric_key, label_map) - - if isoflop_df.empty: - logger.warning("No valid isoflop data after transformation") - return IsoFlopAnalysisResult( - configs=[], - scaling_fits={}, - isoflop_df=pd.DataFrame(), - minima_records=[], - fit_curves={}, - ) - - logger.info(f"Transformed {len(isoflop_df)} runs for scaling ladder analysis") - - fit_result = fit_scaling_laws(isoflop_df) - configs = _minima_to_candidates(fit_result.minima_records) - - return IsoFlopAnalysisResult( - configs=configs, - scaling_fits=fit_result.scaling_fits, - isoflop_df=isoflop_df, - minima_records=fit_result.minima_records, - fit_curves=fit_result.fit_curves, - ) diff --git a/lib/marin/src/marin/scaling_laws/recipe.py b/lib/marin/src/marin/scaling_laws/recipe.py deleted file mode 100644 index e22596b2be..0000000000 --- a/lib/marin/src/marin/scaling_laws/recipe.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Scaling recipes: model-specific hyperparameter bundles for scaling law experiments. - -A ScalingRecipe defines the interface for scaling experiments. Concrete implementations -provide model-specific decisions for: -- Architecture formula (how to compute architecture from target param count) -- Model config building (returns LlamaConfig or subclass) -- Optimizer config building -- Candidate generation for isoflop sweeps - -Orchestration logic (generating train args, predicting optimal configs) lives in -the library functions in isoflop_analysis.py, not in recipes. -""" - -from collections.abc import Iterator -from typing import TYPE_CHECKING, Protocol - -from levanter.models.llama import LlamaConfig -from levanter.optim.config import OptimizerConfig - -if TYPE_CHECKING: - from marin.scaling_laws.isoflop_analysis import CandidateConfig - -# Default constants -DEFAULT_SEQ_LEN = 4096 -DEFAULT_STEPS_PER_RUN = 2**16 # Reference step count for hyperparameter tuning -DEFAULT_FLOP_TOLERANCE = 0.01 # Relative error tolerance for FLOP budget - - -class ScalingRecipe(Protocol): - """Protocol defining the interface for scaling law recipes. - - Concrete implementations (e.g., Marin2025Recipe) should implement these - model-specific methods. 
Orchestration logic (generating training args, - predicting optimal configs) is handled by library functions that use - these core methods. - """ - - name: str - """Name identifying this recipe (e.g., 'marin-2025').""" - - def build_model_config(self, target_params: int, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: - """Build a model config for a target parameter count. - - TODO: LlamaConfig is currently our most generic config type, and we - subclass it to other models (e.g. Qwen, OLMo, etc). We should make - a true generic config class eventually. - """ - ... - - def estimate_memory_bytes(self, model_config: LlamaConfig, batch_size: int, vocab_size: int) -> int: - """Estimate memory usage in bytes for training with this model config.""" - ... - - def build_optimizer_config(self, candidate: "CandidateConfig", vocab_size: int) -> OptimizerConfig: - """Build optimizer config for a candidate.""" - ... - - def candidate_configs( - self, - budget: float, - vocab_size: int, - seq_len: int = DEFAULT_SEQ_LEN, - steps_per_run: int = DEFAULT_STEPS_PER_RUN, - flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, - ) -> "Iterator[CandidateConfig]": - """Yield candidate configurations within the FLOP budget.""" - ... diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 42988c9479..f2230b29a2 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -38,10 +38,10 @@ from marin.processing.tokenize import get_vocab_size_for_tokenizer from marin.scaling_laws.isoflop_analysis import ( ScalingFit, + ScalingRecipe, predict_optimal_config, ) from marin.scaling_laws.tpu_utils import pick_v5p_type -from marin.scaling_laws.recipe import ScalingRecipe from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm logger = logging.getLogger(__name__) diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py index b176d45b75..3e793bdb23 100644 --- a/lib/marin/src/marin/scaling_laws/tpu_utils.py +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -22,8 +22,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from marin.scaling_laws.isoflop_analysis import CandidateConfig - from marin.scaling_laws.recipe import ScalingRecipe + from marin.scaling_laws.isoflop_analysis import CandidateConfig, ScalingRecipe # ---------------- TPU v5p Hardware Constants ---------------- # These constants are specific to TPU v5p pods. 
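Before the test updates below, a compact restatement of the math fit_scaling_laws() now performs: per (label, FLOP bucket) it fits a robust quadratic in log10(tokens) and takes the vertex as the optimum, and across buckets the optima are summarized as D* = A * C^alpha. A hedged sketch of both steps; the power-law fit is shown here as a plain log-log line and may differ in detail from the library's implementation.

```python
import numpy as np

def quad_vertex(a: float, b: float, c: float) -> tuple[float, float]:
    # Minimum of a*x^2 + b*x + c with x = log10(tokens); only meaningful when a > 0.
    x_opt = -b / (2 * a)
    return 10**x_opt, a * x_opt**2 + b * x_opt + c  # (optimal_tokens, loss_at_optimal)

def fit_power_law(budgets: np.ndarray, optimal_tokens: np.ndarray) -> tuple[float, float]:
    # D* ~= A * C**alpha  <=>  log10(D*) = alpha * log10(C) + log10(A)
    alpha, log_a = np.polyfit(np.log10(budgets), np.log10(optimal_tokens), 1)
    return float(alpha), float(10**log_a)
```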
diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index fc4b2e8b6e..d90d019755 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -19,28 +19,25 @@ """ import jax.numpy as jnp -import pandas as pd from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig from levanter.models.qwen import Qwen3Config from marin.scaling_laws.isoflop_analysis import ( DEFAULT_SEQ_LEN, - MARIN_TOKENIZER_VOCAB_SIZE, IsoFlopTrainArgs, + MARIN_TOKENIZER_VOCAB_SIZE, candidate_configs, compute_training_flops, fit_scaling_laws, generate_isoflop_train_args, - parse_isoflop_run_name, robust_quad_logx, solve_for_batch_size, solve_for_train_steps, - transform_metrics_for_isoflop, ) -# Import the concrete recipe from experiments for testing -from experiments.isoflop_sweep import Marin2025Recipe +# Import the concrete recipe and transform function from experiments +from experiments.isoflop_sweep import Marin2025Recipe, parse_isoflop_run_name, transform_levanter_metrics # --- FLOP computation tests --- @@ -241,26 +238,31 @@ def test_end_to_end_analysis_pipeline(): Uses SAMPLE_METRICS_DATA (simulating real wandb metrics) to verify the full pipeline: metrics transformation -> curve fitting -> scaling law extraction. """ - raw_df = pd.DataFrame(SAMPLE_METRICS_DATA) + from marin.scaling_laws import round_flops_to_bucket - # Transform metrics - isoflop_df = transform_metrics_for_isoflop(raw_df, "eval/paloma/c4_en/bpb") - assert len(isoflop_df) == 6 + # Transform metrics using the Levanter transform function + records = transform_levanter_metrics(SAMPLE_METRICS_DATA, "eval/paloma/c4_en/bpb") + assert len(records) == 6 # Fit scaling laws - fit_result = fit_scaling_laws(isoflop_df) + fit_result = fit_scaling_laws(records) - # Should find two minima (one per budget: 1e18 and 1e19) + # Should find two minima (one per budget: ~1e18 and ~1e19) + # FLOP values are bucketed by round_flops_to_bucket assert len(fit_result.minima_records) == 2 - assert {rec.flops for rec in fit_result.minima_records} == {1e18, 1e19} + + # Get expected bucketed values + bucket_1e18 = round_flops_to_bucket(1e18) + bucket_1e19 = round_flops_to_bucket(1e19) + assert {rec.flops for rec in fit_result.minima_records} == {bucket_1e18, bucket_1e19} # Verify fitted minima are near expected optimal points minima_by_flops = {rec.flops: rec for rec in fit_result.minima_records} - # At 1e18: raw data optimal at 2.5B tokens (loss=1.12) - assert abs(minima_by_flops[1e18].optimal_tokens - 2.6e9) < 0.2e9 - assert abs(minima_by_flops[1e18].loss_at_optimal - 1.12) < 0.01 + # At ~1e18: raw data optimal at 2.5B tokens (loss=1.12) + assert abs(minima_by_flops[bucket_1e18].optimal_tokens - 2.6e9) < 0.2e9 + assert abs(minima_by_flops[bucket_1e18].loss_at_optimal - 1.12) < 0.01 - # At 1e19: raw data optimal at 8B tokens (loss=0.98) - assert abs(minima_by_flops[1e19].optimal_tokens - 8.8e9) < 0.2e9 - assert abs(minima_by_flops[1e19].loss_at_optimal - 0.98) < 0.01 + # At ~1e19: raw data optimal at 8B tokens (loss=0.98) + assert abs(minima_by_flops[bucket_1e19].optimal_tokens - 8.8e9) < 0.2e9 + assert abs(minima_by_flops[bucket_1e19].loss_at_optimal - 0.98) < 0.01 From 54181312abce90a4c8d6dd6a85377a193f606096 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Fri, 9 Jan 2026 16:40:48 -0800 Subject: [PATCH 66/79] Try to remove a lot of the batch_size dependency --- experiments/isoflop_sweep.py | 74 +++++++++++---- lib/marin/src/marin/scaling_laws/__init__.py | 2 - .../marin/scaling_laws/eval_metrics_reader.py | 33 ++----- 
.../marin/scaling_laws/isoflop_analysis.py | 91 +++++++++---------- .../src/marin/scaling_laws/scaling_ladder.py | 13 ++- lib/marin/src/marin/scaling_laws/tpu_utils.py | 15 ++- tests/test_scaling_laws.py | 20 ++-- 7 files changed, 136 insertions(+), 112 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 283a942342..4a4ad9b7ab 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -349,15 +349,19 @@ def _build_model_config_from_hidden_size(self, hidden_size: int, seq_len: int = def estimate_memory_bytes( self, - model_config: LlamaConfig, - batch_size: int, + candidate: CandidateConfig, vocab_size: int, + seq_len: int = DEFAULT_SEQ_LEN, optim_mult: int = 3, dtype_size: int = 4, fudge_factor: float = 2.0, ) -> int: """Estimate float32 memory usage in bytes for training.""" - param_count = self._compute_params_for_hidden_size(model_config.hidden_dim, vocab_size) + hidden_size = self.hidden_size_for_params(candidate.target_params, vocab_size) + model_config = self._build_model_config_from_hidden_size(hidden_size, seq_len) + batch_size, _ = self.compute_training_schedule(candidate, vocab_size, seq_len) + + param_count = self._compute_params_for_hidden_size(hidden_size, vocab_size) param_bytes = param_count * optim_mult * dtype_size act_bytes = (batch_size * model_config.max_seq_len) * ( (model_config.hidden_dim * model_config.num_layers) + vocab_size * fudge_factor @@ -365,11 +369,40 @@ def estimate_memory_bytes( total_bytes = param_bytes + act_bytes return int(total_bytes * fudge_factor) - def build_optimizer_config(self, candidate: CandidateConfig, vocab_size: int) -> OptimizerConfig: + def compute_training_schedule( + self, candidate: CandidateConfig, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN + ) -> tuple[int, int]: + """Compute training schedule (batch_size, train_steps) for a candidate.""" + hidden_size = self.hidden_size_for_params(candidate.target_params, vocab_size) + + # Start with batch_size that gives us ~DEFAULT_STEPS_PER_RUN steps for the tokens + target_steps = DEFAULT_STEPS_PER_RUN + batch_exact = candidate.tokens / (target_steps * seq_len) + batch_size = _round_to_power_of_two(batch_exact) + + # Adjust batch_size to respect learning rate constraints + lr = self._compute_learning_rate(batch_size, hidden_size) + while lr > self.max_learning_rate and batch_size >= self.min_batch_size * 2: + batch_size //= 2 + lr = self._compute_learning_rate(batch_size, hidden_size) + + # Ensure minimum batch size + if batch_size < self.min_batch_size: + batch_size = self.min_batch_size + + # Compute train_steps to achieve target tokens + train_steps = round(candidate.tokens / (batch_size * seq_len)) + + return (batch_size, train_steps) + + def build_optimizer_config( + self, candidate: CandidateConfig, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN + ) -> OptimizerConfig: """Build optimizer config for a candidate.""" + batch_size, _ = self.compute_training_schedule(candidate, vocab_size, seq_len) hidden_size = self.hidden_size_for_params(candidate.target_params, vocab_size) - learning_rate = self._compute_learning_rate(candidate.batch_size, hidden_size) - beta2 = self._compute_beta2(candidate.batch_size) + learning_rate = self._compute_learning_rate(batch_size, hidden_size) + beta2 = self._compute_beta2(batch_size) return CautiousConfig( learning_rate=learning_rate, @@ -393,7 +426,12 @@ def candidate_configs( steps_per_run: int = DEFAULT_STEPS_PER_RUN, flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> 
Iterator[CandidateConfig]: - """Yield candidate configurations within the FLOP budget.""" + """Yield candidate configurations within the FLOP budget. + + Iterates over hidden sizes, computes batch_size to hit the FLOP budget, + validates constraints (LR, min batch size), and yields CandidateConfigs + for valid configurations. + """ step_size = self._get_step_size(budget) min_hidden = 2**self.min_hidden_pow max_hidden = 2**self.max_hidden_pow @@ -401,9 +439,11 @@ def candidate_configs( for hidden_size in range(min_hidden, max_hidden + 1, step_size): model_config = self._build_model_config_from_hidden_size(hidden_size, seq_len) + # Compute batch_size to hit FLOP budget batch_exact = solve_for_batch_size(model_config, vocab_size, budget, steps_per_run, seq_len) batch_size = _round_to_power_of_two(batch_exact) + # Adjust batch_size to respect learning rate constraints lr = self._compute_learning_rate(batch_size, hidden_size) while lr > self.max_learning_rate: batch_size //= 2 @@ -414,6 +454,7 @@ def candidate_configs( train_steps = round(solve_for_train_steps(model_config, vocab_size, budget, batch_size, seq_len)) + # Validate achieved FLOPs are within tolerance achieved_flops = 3 * model_config.flops_per_token(vocab_size, seq_len) * batch_size * train_steps * seq_len if abs(achieved_flops - budget) / budget > flop_tolerance: continue @@ -421,11 +462,10 @@ def candidate_configs( tokens = batch_size * train_steps * seq_len target_params = self._compute_params_for_hidden_size(hidden_size, vocab_size) + # Yield simplified CandidateConfig (without batch_size/train_steps) yield CandidateConfig( - batch_size=batch_size, - train_steps=train_steps, - tokens=tokens, target_params=target_params, + tokens=tokens, flops_budget=budget, ) @@ -463,9 +503,6 @@ class IsoFlopAnalysisConfig: metrics_filename: str = "tracker_metrics.jsonl" """Name of the metrics file within each checkpoint directory.""" - backfill_from_wandb: bool = True - """If True, backfill tracker_metrics.jsonl from WandB for runs that completed before this feature.""" - wandb_entity_project: str = f"{WANDB_ENTITY}/{WANDB_PROJECT}" """WandB entity/project to query for backfill (format: 'entity/project').""" @@ -605,24 +642,27 @@ def create_isoflop_sweep_steps( # Build model and optimizer configs using the recipe model_config = recipe.build_model_config(candidate.target_params, vocab_size, seq_len) - optimizer_config = recipe.build_optimizer_config(candidate, vocab_size) + optimizer_config = recipe.build_optimizer_config(candidate, vocab_size, seq_len) tpu_type = pick_v5p_type(candidate, vocab_size, seq_len, recipe) + # Compute training schedule from recipe + batch_size, num_steps = recipe.compute_training_schedule(candidate, vocab_size, seq_len) + # Use local naming with architecture details for backward compatibility run_name = _format_run_name( candidate.flops_budget, model_config.hidden_dim, model_config.num_layers, - candidate.batch_size, + batch_size, experiment_name, ) output_path = f"checkpoints/isoflop/{run_name}" train_cfg = replace( base_train_config, - train_batch_size=candidate.batch_size, + train_batch_size=batch_size, learning_rate=optimizer_config.learning_rate, - num_train_steps=candidate.train_steps, + num_train_steps=num_steps, resources=ResourceConfig.with_tpu(tpu_type), optimizer_config=optimizer_config, ) diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index e76f1c0734..1d9b09d112 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ 
b/lib/marin/src/marin/scaling_laws/__init__.py @@ -27,7 +27,6 @@ QuadraticFitCoeffs, ScalingFit, ScalingRecipe, - candidate_configs, compute_training_flops, fit_scaling_laws, generate_isoflop_train_args, @@ -70,7 +69,6 @@ "ScalingLadderRungConfig", "ScalingRecipe", # Functions - "candidate_configs", "compute_training_flops", "create_isoflop_plot", "create_scaling_plot", diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index 1f3132f55f..20cf3603f8 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -96,15 +96,9 @@ class EvalMetricsAnalysisConfig: training_runs: Sequence[str] """List of training run output paths (executor resolves InputName to str at runtime).""" - output_path: str - """Where to write analysis outputs.""" - metrics_filename: str = "tracker_metrics.jsonl" """Name of the metrics file within each checkpoint directory.""" - backfill_from_wandb: bool = True - """If True, backfill tracker_metrics.jsonl from WandB for runs that completed before this feature.""" - wandb_entity_project: str = f"{WANDB_ENTITY}/{WANDB_PROJECT}" """WandB entity/project to query for backfill (format: 'entity/project').""" @@ -113,7 +107,7 @@ def read_raw_records(config: EvalMetricsAnalysisConfig) -> list[dict]: """Read raw eval metrics from training runs. This is the shared utility that all analysis subtypes use to load metrics. - It handles reading JSONL files and optional WandB backfill. + It handles reading JSONL files and WandB backfill when files are missing. Args: config: Analysis config with training_runs and backfill settings @@ -129,23 +123,16 @@ def read_raw_records(config: EvalMetricsAnalysisConfig) -> list[dict]: fs, _, _ = fsspec.get_fs_token_paths(metrics_file) if not fs.exists(metrics_file): - logger.info(f"{metrics_file} does not exist") - - if config.backfill_from_wandb: - logger.info("Attempting to backfill from WandB...") - - success = _backfill_metrics_from_wandb( - checkpoint_path=run_path, - metrics_file=metrics_file, - entity_project=config.wandb_entity_project, - ) - if not success: - raise RuntimeError( - f"Backfill from WandB failed for run {i} (path={run_path}, metrics_file={metrics_file})" - ) - else: + logger.info(f"{metrics_file} does not exist, attempting to backfill from WandB...") + + success = _backfill_metrics_from_wandb( + checkpoint_path=run_path, + metrics_file=metrics_file, + entity_project=config.wandb_entity_project, + ) + if not success: raise RuntimeError( - f"Metrics file missing for run {i} (path={run_path}), and backfill_from_wandb is disabled" + f"Backfill from WandB failed for run {i} (path={run_path}, metrics_file={metrics_file})" ) with fs.open(metrics_file, "r") as f: diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 68feb33ce0..b371d6705a 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -134,17 +134,19 @@ class CandidateConfig: Contains only the fundamental parameters that scaling laws reason about: - How much compute (flops_budget) - How to allocate it between model size (target_params) and data (tokens) - - Training batch configuration (batch_size, train_steps) - All model-specific details (architecture, optimizer hyperparameters) are - computed by the ScalingRecipe from these values. 
+ The training schedule (batch_size, train_steps) is computed by the + ScalingRecipe at training time via compute_training_schedule(). """ - batch_size: int - train_steps: int - tokens: float # = batch_size * train_steps * seq_len - target_params: int # Optimal parameter count for this flops_budget - flops_budget: float # Compute budget this config was generated for + target_params: int + """Optimal parameter count for this flops_budget.""" + + tokens: float + """Total tokens to train on.""" + + flops_budget: float + """Compute budget this config was generated for.""" class ModelConfiguration(Protocol): @@ -178,11 +180,18 @@ def build_model_config( """Build a model config for a target parameter count.""" ... - def estimate_memory_bytes(self, model_config: ModelConfiguration, batch_size: int, vocab_size: int) -> int: - """Estimate memory usage in bytes for training with this model config.""" + def estimate_memory_bytes(self, candidate: CandidateConfig, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> int: + """Estimate memory usage in bytes for training a candidate configuration. + + The implementation can access candidate.target_params, candidate.tokens, and + candidate.flops_budget to compute memory requirements. This allows the recipe + to compute the actual batch_size (from tokens) when estimating memory. + """ ... - def build_optimizer_config(self, candidate: CandidateConfig, vocab_size: int) -> OptimizerConfig: + def build_optimizer_config( + self, candidate: CandidateConfig, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN + ) -> OptimizerConfig: """Build optimizer config for a candidate.""" ... @@ -194,7 +203,17 @@ def candidate_configs( steps_per_run: int = DEFAULT_STEPS_PER_RUN, flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> Iterator[CandidateConfig]: - """Yield candidate configurations within the FLOP budget.""" + """Yield candidate configurations within the FLOP budget. + + A typical implementation will iterate over hidden sizes (which determine + target_params), compute the batch_size needed to hit the FLOP budget, + and yield configs where the relative FLOP error is within tolerance. + + The implementation should handle model-specific constraints like: + - Hidden size increments (e.g., multiples of 64 or 128) + - Memory constraints affecting maximum batch size + - Architecture-specific parameter count formulas + """ ... @@ -204,7 +223,7 @@ class IsoFlopTrainArgs: This dataclass contains the model-agnostic parameters needed for training. The ScalingRecipe is responsible for converting these to model-specific - configs (model architecture, optimizer hyperparameters). + configs (model architecture, optimizer hyperparameters, training schedule). Naming (run_name, output_path) is intentionally not included here - that's the responsibility of experiment code which may have its own conventions. 
@@ -213,11 +232,12 @@ class IsoFlopTrainArgs: >>> args = generate_isoflop_train_args(budgets, vocab_size, recipe)[0] >>> # Recipe converts candidate to model-specific configs >>> model_config = recipe.build_model_config(args.candidate.target_params, vocab_size) - >>> optimizer_config = recipe.build_optimizer_config(args.candidate) + >>> batch_size, train_steps = recipe.compute_training_schedule(args.candidate, vocab_size) + >>> optimizer_config = recipe.build_optimizer_config(args.candidate, batch_size, vocab_size) """ candidate: CandidateConfig - """Model-agnostic compute allocation (batch_size, train_steps, tokens, target_params).""" + """Model-agnostic compute allocation (target_params, tokens, flops_budget).""" tags: tuple[str, ...] """Tags for tracking/filtering runs.""" @@ -357,34 +377,6 @@ def solve_for_train_steps( return target_flops / (3 * flops_per_token * batch_size * seq_len) -def candidate_configs( - budget: float, - vocab_size: int, - recipe: ScalingRecipe, - seq_len: int = DEFAULT_SEQ_LEN, - steps_per_run: int = DEFAULT_STEPS_PER_RUN, - flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, -) -> Iterator[CandidateConfig]: - """Yield candidate model configurations within the FLOP budget. - - This is a convenience function that delegates to recipe.candidate_configs(). - The recipe encapsulates all model-specific decisions (architecture formula, - search bounds, constraints), while this function provides backward compatibility. - - Args: - budget: Target FLOP budget. - vocab_size: Vocabulary size for the tokenizer. - recipe: ScalingRecipe with architecture/hyperparameter settings. - seq_len: Sequence length for training. - steps_per_run: Reference step count for FLOP budget calculation. - flop_tolerance: Tolerance for matching FLOP budget (relative error). - - Yields: - CandidateConfig objects for each valid configuration. - """ - yield from recipe.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance) - - # ---------------- Training Args Generation ---------------- @@ -425,17 +417,21 @@ def generate_isoflop_train_args( >>> for args in train_args: ... # Recipe converts model-agnostic candidate to model-specific configs ... model_config = recipe.build_model_config(args.candidate.target_params, vocab_size) - ... optimizer_config = recipe.build_optimizer_config(args.candidate, vocab_size) + ... batch_size, train_steps = recipe.compute_training_schedule(args.candidate, vocab_size) + ... 
optimizer_config = recipe.build_optimizer_config(args.candidate, batch_size, vocab_size) """ results: list[IsoFlopTrainArgs] = [] for budget in budgets: for candidate in recipe.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance): + # Compute training schedule from recipe (for tags) + batch_size, train_steps = recipe.compute_training_schedule(candidate, vocab_size, seq_len) + tags = ( f"FLOPs={budget:.1e}", f"N={candidate.target_params:.1e}", - f"B={candidate.batch_size}", - f"steps={candidate.train_steps}", + f"B={batch_size}", + f"steps={train_steps}", f"tokens={candidate.tokens:.1e}", ) @@ -642,8 +638,7 @@ def predict_optimal_config( best = max(candidates, key=lambda c: c.tokens) logger.info( - f"Selected config: N={best.target_params:.2e}, " - f"B={best.batch_size}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})" + f"Selected config: N={best.target_params:.2e}, " f"tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})" ) return best diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index f2230b29a2..82d7805f52 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -88,7 +88,8 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: The recipe handles all model-specific decisions: - Model config is built via `recipe.build_model_config(target_params, vocab_size)` - - Optimizer config is built via `recipe.build_optimizer_config(candidate, vocab_size)` + - Training schedule is built via `recipe.compute_training_schedule(candidate, vocab_size)` + - Optimizer config is built via `recipe.build_optimizer_config(candidate, batch_size, vocab_size)` """ result_path = os.path.join(config.analysis_output_path, "isoflop_analysis_result.json") fs, _, _ = fsspec.get_fs_token_paths(result_path) @@ -121,14 +122,16 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: logger.info( f"Training with optimal config for {config.target_budget:.2e} FLOPs:\n" f" target_params={candidate.target_params:.2e}\n" - f" batch_size={candidate.batch_size}, train_steps={candidate.train_steps}\n" f" tokens={candidate.tokens:.2e}" ) model_cfg = config.recipe.build_model_config(candidate.target_params, vocab_size, config.seq_len) - optimizer_cfg = config.recipe.build_optimizer_config(candidate, vocab_size) + optimizer_cfg = config.recipe.build_optimizer_config(candidate, vocab_size, config.seq_len) tpu_type = pick_v5p_type(candidate, vocab_size, config.seq_len, config.recipe) + # Compute training schedule - recipe-specific, not in protocol + batch_size, train_steps = config.recipe.compute_training_schedule(candidate, vocab_size, config.seq_len) + train_config = TrainLmConfig( data=config.tokenized, trainer=TrainerConfig( @@ -142,8 +145,8 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: ], ), mp=jmp.get_policy("p=f32,c=bfloat16"), - train_batch_size=candidate.batch_size, - num_train_steps=candidate.train_steps, + train_batch_size=batch_size, + num_train_steps=train_steps, steps_per_eval=1000, checkpointer=CheckpointerConfig( save_interval=timedelta(minutes=10), diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py index 3e793bdb23..2a95b63c6f 100644 --- a/lib/marin/src/marin/scaling_laws/tpu_utils.py +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -19,10 +19,8 @@ """ import math -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from 
marin.scaling_laws.isoflop_analysis import CandidateConfig, ScalingRecipe +from marin.scaling_laws.isoflop_analysis import CandidateConfig, ScalingRecipe # ---------------- TPU v5p Hardware Constants ---------------- # These constants are specific to TPU v5p pods. @@ -38,18 +36,18 @@ def pick_v5p_type( - candidate: "CandidateConfig", + candidate: CandidateConfig, vocab_size: int, seq_len: int, - recipe: "ScalingRecipe", + recipe: ScalingRecipe, ) -> str: """Select the smallest TPU v5p slice that fits the model in float32. Args: - candidate: CandidateConfig with target_params and batch_size. + candidate: CandidateConfig with target_params and tokens. vocab_size: Vocabulary size. seq_len: Sequence length. - recipe: ScalingRecipe to determine architecture. + recipe: ScalingRecipe for memory estimation. Returns: TPU slice name, e.g., "v5p-8" or "v5p-32". @@ -57,8 +55,7 @@ def pick_v5p_type( Raises: ValueError: If the model is too large for available v5p slices. """ - model_config = recipe.build_model_config(candidate.target_params, vocab_size, seq_len) - need_bytes = recipe.estimate_memory_bytes(model_config, candidate.batch_size, vocab_size) + need_bytes = recipe.estimate_memory_bytes(candidate, vocab_size, seq_len) chip_bytes = HBM_PER_CHIP_GIB * 1024**3 chips = math.ceil(need_bytes / chip_bytes) cores_req = chips * CORES_PER_CHIP diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index d90d019755..edfb89c934 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -27,7 +27,6 @@ DEFAULT_SEQ_LEN, IsoFlopTrainArgs, MARIN_TOKENIZER_VOCAB_SIZE, - candidate_configs, compute_training_flops, fit_scaling_laws, generate_isoflop_train_args, @@ -99,13 +98,17 @@ def test_candidate_configs_within_tolerance(): flop_tolerance = 0.01 seq_len = DEFAULT_SEQ_LEN - for candidate in candidate_configs(budget, MARIN_TOKENIZER_VOCAB_SIZE, recipe, flop_tolerance=flop_tolerance): + for candidate in recipe.candidate_configs( + budget, MARIN_TOKENIZER_VOCAB_SIZE, seq_len, flop_tolerance=flop_tolerance + ): + # Compute training schedule from recipe + batch_size, train_steps = recipe.compute_training_schedule(candidate, MARIN_TOKENIZER_VOCAB_SIZE, seq_len) model_config = recipe.build_model_config(candidate.target_params, MARIN_TOKENIZER_VOCAB_SIZE, seq_len) achieved = compute_training_flops( model_config, MARIN_TOKENIZER_VOCAB_SIZE, - candidate.batch_size, - candidate.train_steps, + batch_size, + train_steps, seq_len, ) relative_error = abs(achieved - budget) / budget @@ -157,10 +160,11 @@ def test_generate_isoflop_train_args_snapshot(): for i, (args, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_3E18, strict=True)): assert isinstance(args, IsoFlopTrainArgs) - c = args.candidate - assert c.batch_size == expected["batch_size"], f"Config {i}: batch_size mismatch" - assert c.train_steps == expected["train_steps"], f"Config {i}: train_steps mismatch" - assert c.flops_budget == expected["flops_budget"], f"Config {i}: flops_budget mismatch" + # batch_size and train_steps are computed from recipe + batch_size, train_steps = recipe.compute_training_schedule(args.candidate, MARIN_TOKENIZER_VOCAB_SIZE) + assert batch_size == expected["batch_size"], f"Config {i}: batch_size mismatch" + assert train_steps == expected["train_steps"], f"Config {i}: train_steps mismatch" + assert args.candidate.flops_budget == expected["flops_budget"], f"Config {i}: flops_budget mismatch" # --- End-to-end integration test --- From cb3939cc5aa5ffe4f543aa2664e3deee6f09c8fd Mon Sep 17 00:00:00 2001 From: 
Claude Date: Sat, 10 Jan 2026 01:12:27 +0000 Subject: [PATCH 67/79] Refactor: move vocab_size inside ScalingRecipe This refactor moves vocab_size entirely inside the recipe, simplifying the API by removing it from function signatures throughout the scaling laws library. Key changes: - Add vocab_size as a required property on ScalingRecipe protocol - Add vocab_size field to Marin2025Recipe (defaults to 128256) - Remove vocab_size parameter from all ScalingRecipe methods: - build_model_config() - estimate_memory_bytes() - build_optimizer_config() - candidate_configs() - compute_training_schedule() - Remove vocab_size parameter from library functions: - generate_isoflop_train_args() - predict_optimal_config() - predict_optimal_configs_for_budgets() - pick_v5p_type() - Remove MARIN_TOKENIZER_VOCAB_SIZE constant (no longer needed) - Remove vocab_size lookup and threading from scaling_ladder.py - Remove tokenizer field from ScalingLadderRungConfig (recipe owns this) - Update tests to use new API Benefits: - Simpler API: no vocab_size threading through 10+ function calls - Single source of truth: recipe owns its vocab_size - Less error-prone: impossible to accidentally mix vocab_sizes - Better encapsulation: recipe is self-contained for model config The low-level helper functions (compute_training_flops, solve_for_batch_size, solve_for_train_steps) still take vocab_size as they're called by the recipe with self.vocab_size. --- experiments/isoflop_sweep.py | 69 ++++++++++--------- .../marin/scaling_laws/isoflop_analysis.py | 62 +++++++++-------- .../src/marin/scaling_laws/scaling_ladder.py | 30 ++++---- lib/marin/src/marin/scaling_laws/tpu_utils.py | 6 +- tests/test_scaling_laws.py | 18 ++--- 5 files changed, 88 insertions(+), 97 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 4a4ad9b7ab..10b081dc84 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -42,7 +42,7 @@ from experiments.tootsie.exp1295_32b import nemotron_mix from fray.cluster import ResourceConfig from marin.execution.executor import ExecutorStep, InputName, executor_main -from marin.processing.tokenize import get_vocab_size_for_tokenizer, lm_mixture_data_config +from marin.processing.tokenize import lm_mixture_data_config from marin.scaling_laws import ( CandidateConfig, FitScalingLawsResult, @@ -212,9 +212,13 @@ class Marin2025Recipe: """Marin 2025 scaling recipe with all hyperparameters and formulas. This recipe implements all the Marin-specific decisions for scaling experiments. + The vocab_size is derived from the tokenizer, making the recipe self-contained + for all model configuration decisions. """ name: str = "marin-2025" + vocab_size: int = 128256 + """Vocabulary size for the tokenizer. 
Default is for stanford-crfm/marin-tokenizer.""" # --- Learning rate scaling --- # lr = lr_constant * sqrt(batch_size) / hidden_dim @@ -277,14 +281,14 @@ def _get_step_size(self, budget: float) -> int: return self.large_budget_step_size return self.small_budget_step_size - def _compute_params_for_hidden_size(self, hidden_size: int, vocab_size: int) -> int: + def _compute_params_for_hidden_size(self, hidden_size: int) -> int: """Compute approximate parameter count for a given hidden size.""" num_layers = self.compute_num_layers(hidden_size) intermediate_dim = hidden_size * self.mlp_ratio n_heads = max(1, hidden_size // self.hidden_head_ratio) head_size = hidden_size // n_heads - embed_params = vocab_size * hidden_size * 2 + embed_params = self.vocab_size * hidden_size * 2 q_proj = hidden_size * head_size * n_heads kv_proj = 2 * hidden_size * head_size * n_heads o_proj = head_size * n_heads * hidden_size @@ -297,16 +301,16 @@ def _compute_params_for_hidden_size(self, hidden_size: int, vocab_size: int) -> return embed_params + total_layer_params + final_norm - def hidden_size_for_params(self, target_params: int, vocab_size: int) -> int: + def hidden_size_for_params(self, target_params: int) -> int: """Find the hidden size that gives approximately target_params.""" min_hidden = 2**self.min_hidden_pow max_hidden = 2**self.max_hidden_pow best_hidden = min_hidden - best_diff = abs(self._compute_params_for_hidden_size(min_hidden, vocab_size) - target_params) + best_diff = abs(self._compute_params_for_hidden_size(min_hidden) - target_params) for hidden_size in range(min_hidden, max_hidden + 1, 64): - params = self._compute_params_for_hidden_size(hidden_size, vocab_size) + params = self._compute_params_for_hidden_size(hidden_size) diff = abs(params - target_params) if diff < best_diff: best_diff = diff @@ -314,9 +318,9 @@ def hidden_size_for_params(self, target_params: int, vocab_size: int) -> int: return best_hidden - def build_model_config(self, target_params: int, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: + def build_model_config(self, target_params: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: """Build a Qwen3 model config for a target parameter count.""" - hidden_size = self.hidden_size_for_params(target_params, vocab_size) + hidden_size = self.hidden_size_for_params(target_params) num_layers = self.compute_num_layers(hidden_size) intermediate_dim = hidden_size * self.mlp_ratio n_heads = max(1, hidden_size // self.hidden_head_ratio) @@ -350,30 +354,29 @@ def _build_model_config_from_hidden_size(self, hidden_size: int, seq_len: int = def estimate_memory_bytes( self, candidate: CandidateConfig, - vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN, optim_mult: int = 3, dtype_size: int = 4, fudge_factor: float = 2.0, ) -> int: """Estimate float32 memory usage in bytes for training.""" - hidden_size = self.hidden_size_for_params(candidate.target_params, vocab_size) + hidden_size = self.hidden_size_for_params(candidate.target_params) model_config = self._build_model_config_from_hidden_size(hidden_size, seq_len) - batch_size, _ = self.compute_training_schedule(candidate, vocab_size, seq_len) + batch_size, _ = self.compute_training_schedule(candidate, seq_len) - param_count = self._compute_params_for_hidden_size(hidden_size, vocab_size) + param_count = self._compute_params_for_hidden_size(hidden_size) param_bytes = param_count * optim_mult * dtype_size act_bytes = (batch_size * model_config.max_seq_len) * ( - (model_config.hidden_dim * model_config.num_layers) + vocab_size * 
fudge_factor + (model_config.hidden_dim * model_config.num_layers) + self.vocab_size * fudge_factor ) total_bytes = param_bytes + act_bytes return int(total_bytes * fudge_factor) def compute_training_schedule( - self, candidate: CandidateConfig, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN + self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN ) -> tuple[int, int]: """Compute training schedule (batch_size, train_steps) for a candidate.""" - hidden_size = self.hidden_size_for_params(candidate.target_params, vocab_size) + hidden_size = self.hidden_size_for_params(candidate.target_params) # Start with batch_size that gives us ~DEFAULT_STEPS_PER_RUN steps for the tokens target_steps = DEFAULT_STEPS_PER_RUN @@ -396,11 +399,11 @@ def compute_training_schedule( return (batch_size, train_steps) def build_optimizer_config( - self, candidate: CandidateConfig, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN + self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN ) -> OptimizerConfig: """Build optimizer config for a candidate.""" - batch_size, _ = self.compute_training_schedule(candidate, vocab_size, seq_len) - hidden_size = self.hidden_size_for_params(candidate.target_params, vocab_size) + batch_size, _ = self.compute_training_schedule(candidate, seq_len) + hidden_size = self.hidden_size_for_params(candidate.target_params) learning_rate = self._compute_learning_rate(batch_size, hidden_size) beta2 = self._compute_beta2(batch_size) @@ -421,7 +424,6 @@ def build_optimizer_config( def candidate_configs( self, budget: float, - vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN, steps_per_run: int = DEFAULT_STEPS_PER_RUN, flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, @@ -440,7 +442,7 @@ def candidate_configs( model_config = self._build_model_config_from_hidden_size(hidden_size, seq_len) # Compute batch_size to hit FLOP budget - batch_exact = solve_for_batch_size(model_config, vocab_size, budget, steps_per_run, seq_len) + batch_exact = solve_for_batch_size(model_config, self.vocab_size, budget, steps_per_run, seq_len) batch_size = _round_to_power_of_two(batch_exact) # Adjust batch_size to respect learning rate constraints @@ -452,15 +454,17 @@ def candidate_configs( if batch_size < self.min_batch_size: continue - train_steps = round(solve_for_train_steps(model_config, vocab_size, budget, batch_size, seq_len)) + train_steps = round(solve_for_train_steps(model_config, self.vocab_size, budget, batch_size, seq_len)) # Validate achieved FLOPs are within tolerance - achieved_flops = 3 * model_config.flops_per_token(vocab_size, seq_len) * batch_size * train_steps * seq_len + achieved_flops = ( + 3 * model_config.flops_per_token(self.vocab_size, seq_len) * batch_size * train_steps * seq_len + ) if abs(achieved_flops - budget) / budget > flop_tolerance: continue tokens = batch_size * train_steps * seq_len - target_params = self._compute_params_for_hidden_size(hidden_size, vocab_size) + target_params = self._compute_params_for_hidden_size(hidden_size) # Yield simplified CandidateConfig (without batch_size/train_steps) yield CandidateConfig( @@ -593,7 +597,6 @@ def create_isoflop_sweep_steps( experiment_name: str, recipe: ScalingRecipe, budgets: tuple[float, ...] = DEFAULT_BUDGETS, - tokenizer: str = "stanford-crfm/marin-tokenizer", eval_tasks: tuple[EvalTaskConfig, ...] | None = None, seq_len: int = 4096, ) -> tuple[list[ExecutorStep], list[CandidateConfig]]: @@ -605,22 +608,20 @@ def create_isoflop_sweep_steps( Args: tokenized: Tokenized dataset to train on. 
experiment_name: Name suffix for the experiment (e.g., 'nemo', 'dclm'). - recipe: ScalingRecipe with hyperparameters - must be explicitly specified. + recipe: ScalingRecipe with hyperparameters (includes vocab_size). budgets: FLOP budgets to sweep over. - tokenizer: Tokenizer to use for vocab size. eval_tasks: Optional evaluation tasks to run after training. + seq_len: Sequence length for training. Returns: A tuple of: - steps: Training and evaluation ExecutorSteps for the sweep. - candidates: CandidateConfig for each training run with full config details. """ - vocab_size = get_vocab_size_for_tokenizer(tokenizer) - # Library provides the training arguments (model configs, optimizer configs, etc.) + # vocab_size is owned by the recipe train_args_list = generate_isoflop_train_args( budgets=budgets, - vocab_size=vocab_size, recipe=recipe, ) @@ -640,13 +641,13 @@ def create_isoflop_sweep_steps( for args in train_args_list: candidate = args.candidate - # Build model and optimizer configs using the recipe - model_config = recipe.build_model_config(candidate.target_params, vocab_size, seq_len) - optimizer_config = recipe.build_optimizer_config(candidate, vocab_size, seq_len) - tpu_type = pick_v5p_type(candidate, vocab_size, seq_len, recipe) + # Build model and optimizer configs using the recipe (vocab_size is owned by recipe) + model_config = recipe.build_model_config(candidate.target_params, seq_len) + optimizer_config = recipe.build_optimizer_config(candidate, seq_len) + tpu_type = pick_v5p_type(candidate, seq_len, recipe) # Compute training schedule from recipe - batch_size, num_steps = recipe.compute_training_schedule(candidate, vocab_size, seq_len) + batch_size, num_steps = recipe.compute_training_schedule(candidate, seq_len) # Use local naming with architecture details for backward compatibility run_name = _format_run_name( diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index b371d6705a..d86e4772cf 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -50,9 +50,6 @@ DEFAULT_EVAL_METRIC_KEY = "eval/paloma/c4_en/bpb" SEQ_LEN = 4096 -# Marin tokenizer vocab size (stanford-crfm/marin-tokenizer) -MARIN_TOKENIZER_VOCAB_SIZE = 128256 - # ---------------- IsoFLOP Sweep Constants ---------------- # Budgets in training FLOPs (includes 3x multiplier for forward + backward pass). # This matches how FLOPs are tracked in WandB via Levanter's log_performance_stats. @@ -169,18 +166,25 @@ class ScalingRecipe(Protocol): model-specific methods. Orchestration logic (generating training args, predicting optimal configs) is handled by library functions that use these core methods. + + The recipe owns the vocab_size, which is derived from the tokenizer choice. + This ensures consistency and simplifies the API by not requiring vocab_size + to be threaded through every function call. """ name: str """Name identifying this recipe (e.g., 'marin-2025').""" + vocab_size: int + """Vocabulary size for the tokenizer used with this recipe.""" + def build_model_config( - self, target_params: int, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN + self, target_params: int, seq_len: int = DEFAULT_SEQ_LEN ) -> ModelConfiguration: """Build a model config for a target parameter count.""" ... 
- def estimate_memory_bytes(self, candidate: CandidateConfig, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> int: + def estimate_memory_bytes(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> int: """Estimate memory usage in bytes for training a candidate configuration. The implementation can access candidate.target_params, candidate.tokens, and @@ -190,7 +194,7 @@ def estimate_memory_bytes(self, candidate: CandidateConfig, vocab_size: int, seq ... def build_optimizer_config( - self, candidate: CandidateConfig, vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN + self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN ) -> OptimizerConfig: """Build optimizer config for a candidate.""" ... @@ -198,7 +202,6 @@ def build_optimizer_config( def candidate_configs( self, budget: float, - vocab_size: int, seq_len: int = DEFAULT_SEQ_LEN, steps_per_run: int = DEFAULT_STEPS_PER_RUN, flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, @@ -216,6 +219,12 @@ def candidate_configs( """ ... + def compute_training_schedule( + self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN + ) -> tuple[int, int]: + """Compute training schedule (batch_size, train_steps) for a candidate.""" + ... + @dataclass class IsoFlopTrainArgs: @@ -229,11 +238,11 @@ class IsoFlopTrainArgs: the responsibility of experiment code which may have its own conventions. Example: - >>> args = generate_isoflop_train_args(budgets, vocab_size, recipe)[0] - >>> # Recipe converts candidate to model-specific configs - >>> model_config = recipe.build_model_config(args.candidate.target_params, vocab_size) - >>> batch_size, train_steps = recipe.compute_training_schedule(args.candidate, vocab_size) - >>> optimizer_config = recipe.build_optimizer_config(args.candidate, batch_size, vocab_size) + >>> args = generate_isoflop_train_args(budgets, recipe)[0] + >>> # Recipe converts candidate to model-specific configs (vocab_size is owned by recipe) + >>> model_config = recipe.build_model_config(args.candidate.target_params) + >>> batch_size, train_steps = recipe.compute_training_schedule(args.candidate) + >>> optimizer_config = recipe.build_optimizer_config(args.candidate) """ candidate: CandidateConfig @@ -382,7 +391,6 @@ def solve_for_train_steps( def generate_isoflop_train_args( budgets: Sequence[float], - vocab_size: int, recipe: ScalingRecipe, seq_len: int = DEFAULT_SEQ_LEN, steps_per_run: int = DEFAULT_STEPS_PER_RUN, @@ -396,8 +404,7 @@ def generate_isoflop_train_args( Args: budgets: Sequence of FLOP budgets to generate configs for. - vocab_size: Vocabulary size for the tokenizer. - recipe: ScalingRecipe with architecture/hyperparameter settings. + recipe: ScalingRecipe with architecture/hyperparameter settings (includes vocab_size). seq_len: Sequence length for training. steps_per_run: Reference step count for FLOP budget calculation. flop_tolerance: Tolerance for matching FLOP budget. @@ -408,24 +415,23 @@ def generate_isoflop_train_args( Example: >>> from marin.scaling_laws import generate_isoflop_train_args, DEFAULT_BUDGETS >>> # Use a concrete recipe implementation (e.g., from experiments/isoflop_sweep.py) - >>> # recipe = Marin2025Recipe() + >>> # recipe = Marin2025Recipe() # vocab_size is a property of the recipe >>> train_args = generate_isoflop_train_args( ... budgets=DEFAULT_BUDGETS, - ... vocab_size=128256, ... recipe=recipe, ... ) >>> for args in train_args: ... # Recipe converts model-agnostic candidate to model-specific configs - ... 
model_config = recipe.build_model_config(args.candidate.target_params, vocab_size) - ... batch_size, train_steps = recipe.compute_training_schedule(args.candidate, vocab_size) - ... optimizer_config = recipe.build_optimizer_config(args.candidate, batch_size, vocab_size) + ... model_config = recipe.build_model_config(args.candidate.target_params) + ... batch_size, train_steps = recipe.compute_training_schedule(args.candidate) + ... optimizer_config = recipe.build_optimizer_config(args.candidate) """ results: list[IsoFlopTrainArgs] = [] for budget in budgets: - for candidate in recipe.candidate_configs(budget, vocab_size, seq_len, steps_per_run, flop_tolerance): + for candidate in recipe.candidate_configs(budget, seq_len, steps_per_run, flop_tolerance): # Compute training schedule from recipe (for tags) - batch_size, train_steps = recipe.compute_training_schedule(candidate, vocab_size, seq_len) + batch_size, train_steps = recipe.compute_training_schedule(candidate, seq_len) tags = ( f"FLOPs={budget:.1e}", @@ -588,7 +594,6 @@ def predict_optimal_config( scaling_fits: dict[str, ScalingFit], target_flops: float, label: str, - vocab_size: int, recipe: ScalingRecipe, seq_len: int = DEFAULT_SEQ_LEN, steps_per_run: int = DEFAULT_STEPS_PER_RUN, @@ -607,8 +612,7 @@ def predict_optimal_config( scaling_fits: Dict of {label: ScalingFit} from scaling ladder result. target_flops: Target compute budget in FLOPs. label: Dataset/experiment label to use for scaling fit. - vocab_size: Vocabulary size. - recipe: ScalingRecipe with architecture/hyperparameter settings. + recipe: ScalingRecipe with architecture/hyperparameter settings (includes vocab_size). seq_len: Sequence length for training. steps_per_run: Reference step count for FLOP budget calculation. flop_tolerance: Tolerance for matching FLOP budget. @@ -626,7 +630,7 @@ def predict_optimal_config( logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") - candidates = list(recipe.candidate_configs(target_flops, vocab_size, seq_len, steps_per_run, flop_tolerance)) + candidates = list(recipe.candidate_configs(target_flops, seq_len, steps_per_run, flop_tolerance)) if not candidates: logger.warning(f"No valid candidates found for budget {target_flops:.2e}") @@ -648,7 +652,6 @@ def predict_optimal_configs_for_budgets( scaling_fits: dict[str, ScalingFit], target_budgets: list[float], label: str, - vocab_size: int, recipe: ScalingRecipe, seq_len: int = DEFAULT_SEQ_LEN, steps_per_run: int = DEFAULT_STEPS_PER_RUN, @@ -660,8 +663,7 @@ def predict_optimal_configs_for_budgets( scaling_fits: Dict of {label: ScalingFit} from scaling ladder result. target_budgets: List of target compute budgets in FLOPs. label: Dataset/experiment label to use for scaling fit. - vocab_size: Vocabulary size. - recipe: ScalingRecipe with architecture/hyperparameter settings. + recipe: ScalingRecipe with architecture/hyperparameter settings (includes vocab_size). seq_len: Sequence length for training. steps_per_run: Reference step count for FLOP budget calculation. flop_tolerance: Tolerance for matching FLOP budget. 
@@ -675,7 +677,7 @@ def predict_optimal_configs_for_budgets( configs = [] for budget in target_budgets: config = predict_optimal_config( - scaling_fits, budget, label, vocab_size, recipe, seq_len, steps_per_run, flop_tolerance + scaling_fits, budget, label, recipe, seq_len, steps_per_run, flop_tolerance ) if config is None: raise RuntimeError( diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 82d7805f52..c779ee15a5 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -35,7 +35,6 @@ from levanter.trainer import TrainerConfig from levanter.utils.mesh import MeshConfig -from marin.processing.tokenize import get_vocab_size_for_tokenizer from marin.scaling_laws.isoflop_analysis import ( ScalingFit, ScalingRecipe, @@ -55,7 +54,8 @@ class ScalingLadderRungConfig: the target compute budget. At runtime, the optimal config is loaded from the analysis output. - The ScalingRecipe handles all model-specific decisions (architecture, optimizer). + The ScalingRecipe handles all model-specific decisions (architecture, optimizer) + and owns the vocab_size (derived from the tokenizer choice). """ analysis_output_path: str @@ -74,10 +74,7 @@ class ScalingLadderRungConfig: """Where to write training outputs.""" recipe: ScalingRecipe - """Scaling recipe that handles model/optimizer config building.""" - - tokenizer: str = "stanford-crfm/marin-tokenizer" - """Tokenizer to use.""" + """Scaling recipe that handles model/optimizer config building (includes vocab_size).""" seq_len: int = 4096 """Sequence length for training.""" @@ -86,10 +83,10 @@ class ScalingLadderRungConfig: def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: """Run one rung of the scaling ladder (one compute-optimal training run). 
- The recipe handles all model-specific decisions: - - Model config is built via `recipe.build_model_config(target_params, vocab_size)` - - Training schedule is built via `recipe.compute_training_schedule(candidate, vocab_size)` - - Optimizer config is built via `recipe.build_optimizer_config(candidate, batch_size, vocab_size)` + The recipe handles all model-specific decisions (vocab_size is owned by the recipe): + - Model config is built via `recipe.build_model_config(target_params)` + - Training schedule is built via `recipe.compute_training_schedule(candidate)` + - Optimizer config is built via `recipe.build_optimizer_config(candidate)` """ result_path = os.path.join(config.analysis_output_path, "isoflop_analysis_result.json") fs, _, _ = fsspec.get_fs_token_paths(result_path) @@ -103,13 +100,10 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: raise ValueError(f"Expected 2 scaling fit values for '{key}', got {len(value)}") scaling_fits[key] = ScalingFit(float(value[0]), float(value[1])) - vocab_size = get_vocab_size_for_tokenizer(config.tokenizer) - candidate = predict_optimal_config( scaling_fits=scaling_fits, target_flops=config.target_budget, label=config.label, - vocab_size=vocab_size, recipe=config.recipe, seq_len=config.seq_len, ) @@ -125,12 +119,12 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: f" tokens={candidate.tokens:.2e}" ) - model_cfg = config.recipe.build_model_config(candidate.target_params, vocab_size, config.seq_len) - optimizer_cfg = config.recipe.build_optimizer_config(candidate, vocab_size, config.seq_len) - tpu_type = pick_v5p_type(candidate, vocab_size, config.seq_len, config.recipe) + model_cfg = config.recipe.build_model_config(candidate.target_params, config.seq_len) + optimizer_cfg = config.recipe.build_optimizer_config(candidate, config.seq_len) + tpu_type = pick_v5p_type(candidate, config.seq_len, config.recipe) - # Compute training schedule - recipe-specific, not in protocol - batch_size, train_steps = config.recipe.compute_training_schedule(candidate, vocab_size, config.seq_len) + # Compute training schedule - recipe-specific + batch_size, train_steps = config.recipe.compute_training_schedule(candidate, config.seq_len) train_config = TrainLmConfig( data=config.tokenized, diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py index 2a95b63c6f..044a708cf0 100644 --- a/lib/marin/src/marin/scaling_laws/tpu_utils.py +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -37,7 +37,6 @@ def pick_v5p_type( candidate: CandidateConfig, - vocab_size: int, seq_len: int, recipe: ScalingRecipe, ) -> str: @@ -45,9 +44,8 @@ def pick_v5p_type( Args: candidate: CandidateConfig with target_params and tokens. - vocab_size: Vocabulary size. seq_len: Sequence length. - recipe: ScalingRecipe for memory estimation. + recipe: ScalingRecipe for memory estimation (includes vocab_size). Returns: TPU slice name, e.g., "v5p-8" or "v5p-32". @@ -55,7 +53,7 @@ def pick_v5p_type( Raises: ValueError: If the model is too large for available v5p slices. 
""" - need_bytes = recipe.estimate_memory_bytes(candidate, vocab_size, seq_len) + need_bytes = recipe.estimate_memory_bytes(candidate, seq_len) chip_bytes = HBM_PER_CHIP_GIB * 1024**3 chips = math.ceil(need_bytes / chip_bytes) cores_req = chips * CORES_PER_CHIP diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index edfb89c934..a43a9ea8a0 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -26,7 +26,6 @@ from marin.scaling_laws.isoflop_analysis import ( DEFAULT_SEQ_LEN, IsoFlopTrainArgs, - MARIN_TOKENIZER_VOCAB_SIZE, compute_training_flops, fit_scaling_laws, generate_isoflop_train_args, @@ -98,15 +97,13 @@ def test_candidate_configs_within_tolerance(): flop_tolerance = 0.01 seq_len = DEFAULT_SEQ_LEN - for candidate in recipe.candidate_configs( - budget, MARIN_TOKENIZER_VOCAB_SIZE, seq_len, flop_tolerance=flop_tolerance - ): - # Compute training schedule from recipe - batch_size, train_steps = recipe.compute_training_schedule(candidate, MARIN_TOKENIZER_VOCAB_SIZE, seq_len) - model_config = recipe.build_model_config(candidate.target_params, MARIN_TOKENIZER_VOCAB_SIZE, seq_len) + for candidate in recipe.candidate_configs(budget, seq_len, flop_tolerance=flop_tolerance): + # Compute training schedule from recipe (vocab_size is owned by recipe) + batch_size, train_steps = recipe.compute_training_schedule(candidate, seq_len) + model_config = recipe.build_model_config(candidate.target_params, seq_len) achieved = compute_training_flops( model_config, - MARIN_TOKENIZER_VOCAB_SIZE, + recipe.vocab_size, batch_size, train_steps, seq_len, @@ -152,7 +149,6 @@ def test_generate_isoflop_train_args_snapshot(): recipe = Marin2025Recipe() result = generate_isoflop_train_args( budgets=(3e18,), - vocab_size=MARIN_TOKENIZER_VOCAB_SIZE, recipe=recipe, ) @@ -160,8 +156,8 @@ def test_generate_isoflop_train_args_snapshot(): for i, (args, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_3E18, strict=True)): assert isinstance(args, IsoFlopTrainArgs) - # batch_size and train_steps are computed from recipe - batch_size, train_steps = recipe.compute_training_schedule(args.candidate, MARIN_TOKENIZER_VOCAB_SIZE) + # batch_size and train_steps are computed from recipe (vocab_size is owned by recipe) + batch_size, train_steps = recipe.compute_training_schedule(args.candidate) assert batch_size == expected["batch_size"], f"Config {i}: batch_size mismatch" assert train_steps == expected["train_steps"], f"Config {i}: train_steps mismatch" assert args.candidate.flops_budget == expected["flops_budget"], f"Config {i}: flops_budget mismatch" From b0efefa5e5992c565949f48d5958937716808e63 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 10 Jan 2026 02:43:09 +0000 Subject: [PATCH 68/79] Use get_vocab_size_for_tokenizer to derive vocab_size from tokenizer Instead of hardcoding vocab_size, derive it from the tokenizer field using the get_vocab_size_for_tokenizer utility. This maintains the connection between tokenizer and vocab_size and makes it clear where the value comes from. The utility function is already cached with @lru_cache, so repeated calls are efficient. 
--- experiments/isoflop_sweep.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 10b081dc84..7dfcd62761 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -42,7 +42,7 @@ from experiments.tootsie.exp1295_32b import nemotron_mix from fray.cluster import ResourceConfig from marin.execution.executor import ExecutorStep, InputName, executor_main -from marin.processing.tokenize import lm_mixture_data_config +from marin.processing.tokenize import get_vocab_size_for_tokenizer, lm_mixture_data_config from marin.scaling_laws import ( CandidateConfig, FitScalingLawsResult, @@ -217,8 +217,13 @@ class Marin2025Recipe: """ name: str = "marin-2025" - vocab_size: int = 128256 - """Vocabulary size for the tokenizer. Default is for stanford-crfm/marin-tokenizer.""" + tokenizer: str = "stanford-crfm/marin-tokenizer" + """Tokenizer to use. vocab_size is derived from this.""" + + @property + def vocab_size(self) -> int: + """Vocabulary size derived from the tokenizer.""" + return get_vocab_size_for_tokenizer(self.tokenizer) # --- Learning rate scaling --- # lr = lr_constant * sqrt(batch_size) / hidden_dim From 78620f71c7396c95a85be5c70c99495db90a999f Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 12 Jan 2026 11:29:06 -0800 Subject: [PATCH 69/79] Weird Roundtripping --- experiments/exp1603_subgroup_evals.py | 23 ++-- experiments/isoflop_sweep.py | 84 ++------------ .../marin/scaling_laws/isoflop_analysis.py | 109 +++++++++--------- .../src/marin/scaling_laws/scaling_ladder.py | 9 +- lib/marin/src/marin/scaling_laws/tpu_utils.py | 2 +- tests/test_scaling_laws.py | 2 +- 6 files changed, 83 insertions(+), 146 deletions(-) diff --git a/experiments/exp1603_subgroup_evals.py b/experiments/exp1603_subgroup_evals.py index 30e63b0fd0..ba0ddcbe6f 100644 --- a/experiments/exp1603_subgroup_evals.py +++ b/experiments/exp1603_subgroup_evals.py @@ -20,7 +20,7 @@ from experiments.llama import llama3_tokenizer from experiments.exp1342_gemstones_scaling_law import distributional_eval_sets -from experiments.isoflop_sweep import MARIN_2025_RECIPE, MARIN_SCALING_SUITES +from experiments.isoflop_sweep import MARIN_SCALING_SUITES from experiments.models import ModelConfig, download_model_step from marin.execution.executor import executor_main, output_path_of, versioned from marin.evaluation.log_probs import default_lm_log_probs @@ -45,8 +45,9 @@ def create_eval_steps() -> list: steps = [] dist_eval = distributional_eval_sets(llama3_tokenizer) for model, candidate in list(zip(*MARIN_SCALING_SUITES["nemotron"], strict=False)): - total_tokens = candidate.batch_size * candidate.train_steps * 4096 - name = f"marin-nemo-{candidate.flops_budget:.0e}C-{total_tokens}T-" f"N{candidate.target_params:.0e}" + total_tokens = int(candidate.tokens) + params = candidate.model_config.total_trainable_params(VOCAB_SIZE) + name = f"marin-nemo-{candidate.flops_budget:.0e}C-{total_tokens}T-N{params:.0e}" step = evaluate_levanter_lm_evaluation_harness( model_name=name, @@ -56,7 +57,7 @@ def create_eval_steps() -> list: ) steps.append(step) - model_config = MARIN_2025_RECIPE.build_model_config(candidate.target_params, VOCAB_SIZE) + model_config = candidate.model_config logprobs_step = default_lm_log_probs( output_path_of(model).cd("checkpoints"), model_config, @@ -69,8 +70,9 @@ def create_eval_steps() -> list: steps.append(logprobs_step) for model, candidate in list(zip(*MARIN_SCALING_SUITES["common_pile"], strict=False)): - 
total_tokens = candidate.batch_size * candidate.train_steps * 4096 - name = f"marin-comma-{candidate.flops_budget:.0e}C-{total_tokens}T-" f"N{candidate.target_params:.0e}" + total_tokens = int(candidate.tokens) + params = candidate.model_config.total_trainable_params(VOCAB_SIZE) + name = f"marin-comma-{candidate.flops_budget:.0e}C-{total_tokens}T-N{params:.0e}" step = evaluate_levanter_lm_evaluation_harness( model_name=name, @@ -80,7 +82,7 @@ def create_eval_steps() -> list: ) steps.append(step) - model_config = MARIN_2025_RECIPE.build_model_config(candidate.target_params, VOCAB_SIZE) + model_config = candidate.model_config logprobs_step = default_lm_log_probs( output_path_of(model).cd("checkpoints"), model_config, @@ -93,8 +95,9 @@ def create_eval_steps() -> list: steps.append(logprobs_step) for model, candidate in list(zip(*MARIN_SCALING_SUITES["dclm-default"], strict=False)): - total_tokens = candidate.batch_size * candidate.train_steps * 4096 - name = f"marin-dclm-{candidate.flops_budget:.0e}C-{total_tokens}T-" f"N{candidate.target_params:.0e}" + total_tokens = int(candidate.tokens) + params = candidate.model_config.total_trainable_params(VOCAB_SIZE) + name = f"marin-dclm-{candidate.flops_budget:.0e}C-{total_tokens}T-N{params:.0e}" step = evaluate_levanter_lm_evaluation_harness( model_name=name, @@ -104,7 +107,7 @@ def create_eval_steps() -> list: ) steps.append(step) - model_config = MARIN_2025_RECIPE.build_model_config(candidate.target_params, VOCAB_SIZE) + model_config = candidate.model_config logprobs_step = default_lm_log_probs( output_path_of(model).cd("checkpoints"), model_config, diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 7dfcd62761..7bf109c86a 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -286,60 +286,6 @@ def _get_step_size(self, budget: float) -> int: return self.large_budget_step_size return self.small_budget_step_size - def _compute_params_for_hidden_size(self, hidden_size: int) -> int: - """Compute approximate parameter count for a given hidden size.""" - num_layers = self.compute_num_layers(hidden_size) - intermediate_dim = hidden_size * self.mlp_ratio - n_heads = max(1, hidden_size // self.hidden_head_ratio) - head_size = hidden_size // n_heads - - embed_params = self.vocab_size * hidden_size * 2 - q_proj = hidden_size * head_size * n_heads - kv_proj = 2 * hidden_size * head_size * n_heads - o_proj = head_size * n_heads * hidden_size - attn_params = q_proj + kv_proj + o_proj - mlp_params = 3 * hidden_size * intermediate_dim - norm_params = 2 * hidden_size - layer_params = attn_params + mlp_params + norm_params - total_layer_params = num_layers * layer_params - final_norm = hidden_size - - return embed_params + total_layer_params + final_norm - - def hidden_size_for_params(self, target_params: int) -> int: - """Find the hidden size that gives approximately target_params.""" - min_hidden = 2**self.min_hidden_pow - max_hidden = 2**self.max_hidden_pow - - best_hidden = min_hidden - best_diff = abs(self._compute_params_for_hidden_size(min_hidden) - target_params) - - for hidden_size in range(min_hidden, max_hidden + 1, 64): - params = self._compute_params_for_hidden_size(hidden_size) - diff = abs(params - target_params) - if diff < best_diff: - best_diff = diff - best_hidden = hidden_size - - return best_hidden - - def build_model_config(self, target_params: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: - """Build a Qwen3 model config for a target parameter count.""" - hidden_size = 
self.hidden_size_for_params(target_params) - num_layers = self.compute_num_layers(hidden_size) - intermediate_dim = hidden_size * self.mlp_ratio - n_heads = max(1, hidden_size // self.hidden_head_ratio) - - return Qwen3Config( - hidden_dim=hidden_size, - intermediate_dim=intermediate_dim, - num_layers=num_layers, - num_heads=n_heads, - num_kv_heads=n_heads, - max_seq_len=seq_len, - rope=Llama3RotaryEmbeddingsConfig(), - ) - def _build_model_config_from_hidden_size(self, hidden_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: """Build model config from hidden_size directly.""" num_layers = self.compute_num_layers(hidden_size) @@ -365,11 +311,10 @@ def estimate_memory_bytes( fudge_factor: float = 2.0, ) -> int: """Estimate float32 memory usage in bytes for training.""" - hidden_size = self.hidden_size_for_params(candidate.target_params) - model_config = self._build_model_config_from_hidden_size(hidden_size, seq_len) + model_config = candidate.model_config batch_size, _ = self.compute_training_schedule(candidate, seq_len) - param_count = self._compute_params_for_hidden_size(hidden_size) + param_count = model_config.total_trainable_params(self.vocab_size) param_bytes = param_count * optim_mult * dtype_size act_bytes = (batch_size * model_config.max_seq_len) * ( (model_config.hidden_dim * model_config.num_layers) + self.vocab_size * fudge_factor @@ -377,11 +322,9 @@ def estimate_memory_bytes( total_bytes = param_bytes + act_bytes return int(total_bytes * fudge_factor) - def compute_training_schedule( - self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN - ) -> tuple[int, int]: + def compute_training_schedule(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> tuple[int, int]: """Compute training schedule (batch_size, train_steps) for a candidate.""" - hidden_size = self.hidden_size_for_params(candidate.target_params) + hidden_size = candidate.model_config.hidden_dim # Start with batch_size that gives us ~DEFAULT_STEPS_PER_RUN steps for the tokens target_steps = DEFAULT_STEPS_PER_RUN @@ -403,12 +346,10 @@ def compute_training_schedule( return (batch_size, train_steps) - def build_optimizer_config( - self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN - ) -> OptimizerConfig: + def build_optimizer_config(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> OptimizerConfig: """Build optimizer config for a candidate.""" batch_size, _ = self.compute_training_schedule(candidate, seq_len) - hidden_size = self.hidden_size_for_params(candidate.target_params) + hidden_size = candidate.model_config.hidden_dim learning_rate = self._compute_learning_rate(batch_size, hidden_size) beta2 = self._compute_beta2(batch_size) @@ -435,9 +376,8 @@ def candidate_configs( ) -> Iterator[CandidateConfig]: """Yield candidate configurations within the FLOP budget. - Iterates over hidden sizes, computes batch_size to hit the FLOP budget, - validates constraints (LR, min batch size), and yields CandidateConfigs - for valid configurations. + Iterates over feasible model architectures, computes tokens to hit the + FLOP budget, and yields CandidateConfigs with the model_config directly. 
""" step_size = self._get_step_size(budget) min_hidden = 2**self.min_hidden_pow @@ -469,11 +409,9 @@ def candidate_configs( continue tokens = batch_size * train_steps * seq_len - target_params = self._compute_params_for_hidden_size(hidden_size) - # Yield simplified CandidateConfig (without batch_size/train_steps) yield CandidateConfig( - target_params=target_params, + model_config=model_config, tokens=tokens, flops_budget=budget, ) @@ -646,8 +584,8 @@ def create_isoflop_sweep_steps( for args in train_args_list: candidate = args.candidate - # Build model and optimizer configs using the recipe (vocab_size is owned by recipe) - model_config = recipe.build_model_config(candidate.target_params, seq_len) + # Model config is on the candidate; build optimizer config using the recipe + model_config = candidate.model_config optimizer_config = recipe.build_optimizer_config(candidate, seq_len) tpu_type = pick_v5p_type(candidate, seq_len, recipe) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index d86e4772cf..5249015f07 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -121,23 +121,45 @@ class IsoFlopRecord: DEFAULT_FLOP_TOLERANCE = 0.01 # Relative error tolerance for FLOP budget +# ---------------- Model Configuration Protocol ---------------- + + +class ModelConfiguration(Protocol): + """Protocol for model configs used in scaling law calculations. + + Any model config that implements these methods can be used with the + scaling law functions. This allows the library to be model-agnostic + while still working with LlamaConfig, QwenConfig, etc. + """ + + def flops_per_token(self, vocab_size: int, seq_len: int) -> float: + """Return FLOPs per token for this model configuration.""" + ... + + def total_trainable_params(self, vocab_size: int) -> int: + """Return total trainable parameter count for this model configuration.""" + ... + + # ---------------- Candidate Config ---------------- @dataclass class CandidateConfig: - """Model-agnostic compute allocation from scaling law analysis. + """Compute allocation from scaling law analysis. - Contains only the fundamental parameters that scaling laws reason about: - - How much compute (flops_budget) - - How to allocate it between model size (target_params) and data (tokens) + Contains the model configuration and training parameters: + - model_config: The actual model architecture (satisfies ModelConfiguration protocol) + - tokens: How many tokens to train on + - flops_budget: The compute budget this config was generated for - The training schedule (batch_size, train_steps) is computed by the - ScalingRecipe at training time via compute_training_schedule(). + Parameter count is derived from model_config.total_trainable_params(vocab_size). + Training schedule (batch_size, train_steps) is computed by the ScalingRecipe + at training time via compute_training_schedule(). """ - target_params: int - """Optimal parameter count for this flops_budget.""" + model_config: ModelConfiguration + """Model configuration for this candidate.""" tokens: float """Total tokens to train on.""" @@ -146,19 +168,6 @@ class CandidateConfig: """Compute budget this config was generated for.""" -class ModelConfiguration(Protocol): - """Protocol for model configs used in scaling law calculations. - - Any model config that implements flops_per_token can be used with the - scaling law functions. 
This allows the library to be model-agnostic - while still working with LlamaConfig, QwenConfig, etc. - """ - - def flops_per_token(self, vocab_size: int, seq_len: int) -> float: - """Return FLOPs per token for this model configuration.""" - ... - - class ScalingRecipe(Protocol): """Protocol defining the interface for scaling law recipes. @@ -178,24 +187,16 @@ class ScalingRecipe(Protocol): vocab_size: int """Vocabulary size for the tokenizer used with this recipe.""" - def build_model_config( - self, target_params: int, seq_len: int = DEFAULT_SEQ_LEN - ) -> ModelConfiguration: - """Build a model config for a target parameter count.""" - ... - def estimate_memory_bytes(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> int: """Estimate memory usage in bytes for training a candidate configuration. - The implementation can access candidate.target_params, candidate.tokens, and + The implementation can access candidate.model_config, candidate.tokens, and candidate.flops_budget to compute memory requirements. This allows the recipe to compute the actual batch_size (from tokens) when estimating memory. """ ... - def build_optimizer_config( - self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN - ) -> OptimizerConfig: + def build_optimizer_config(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> OptimizerConfig: """Build optimizer config for a candidate.""" ... @@ -208,20 +209,18 @@ def candidate_configs( ) -> Iterator[CandidateConfig]: """Yield candidate configurations within the FLOP budget. - A typical implementation will iterate over hidden sizes (which determine - target_params), compute the batch_size needed to hit the FLOP budget, - and yield configs where the relative FLOP error is within tolerance. + Each candidate includes the model_config directly. A typical implementation + will iterate over feasible model architectures, compute the tokens needed + to hit the FLOP budget, and yield configs where the relative FLOP error + is within tolerance. The implementation should handle model-specific constraints like: - Hidden size increments (e.g., multiples of 64 or 128) - Memory constraints affecting maximum batch size - - Architecture-specific parameter count formulas """ ... - def compute_training_schedule( - self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN - ) -> tuple[int, int]: + def compute_training_schedule(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> tuple[int, int]: """Compute training schedule (batch_size, train_steps) for a candidate.""" ... @@ -230,23 +229,22 @@ def compute_training_schedule( class IsoFlopTrainArgs: """Arguments needed to set up an isoflop training run. - This dataclass contains the model-agnostic parameters needed for training. - The ScalingRecipe is responsible for converting these to model-specific - configs (model architecture, optimizer hyperparameters, training schedule). + This dataclass contains the parameters needed for training. + The ScalingRecipe is responsible for computing training schedules + and optimizer hyperparameters from the candidate. Naming (run_name, output_path) is intentionally not included here - that's the responsibility of experiment code which may have its own conventions. 
Example: >>> args = generate_isoflop_train_args(budgets, recipe)[0] - >>> # Recipe converts candidate to model-specific configs (vocab_size is owned by recipe) - >>> model_config = recipe.build_model_config(args.candidate.target_params) + >>> model_config = args.candidate.model_config # Model config is on the candidate >>> batch_size, train_steps = recipe.compute_training_schedule(args.candidate) >>> optimizer_config = recipe.build_optimizer_config(args.candidate) """ candidate: CandidateConfig - """Model-agnostic compute allocation (target_params, tokens, flops_budget).""" + """Compute allocation (model_config, tokens, flops_budget).""" tags: tuple[str, ...] """Tags for tracking/filtering runs.""" @@ -396,11 +394,11 @@ def generate_isoflop_train_args( steps_per_run: int = DEFAULT_STEPS_PER_RUN, flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> list[IsoFlopTrainArgs]: - """Generate model-agnostic training arguments for each candidate in an isoflop sweep. + """Generate training arguments for each candidate in an isoflop sweep. - Returns IsoFlopTrainArgs containing model-agnostic CandidateConfig objects. - Use recipe.build_model_config() and recipe.build_optimizer_config() to get - model-specific configs. Naming (run_name, output_path) is left to the caller. + Returns IsoFlopTrainArgs containing CandidateConfig objects with model configs. + Use recipe.build_optimizer_config() to get optimizer configs. + Naming (run_name, output_path) is left to the caller. Args: budgets: Sequence of FLOP budgets to generate configs for. @@ -421,8 +419,7 @@ def generate_isoflop_train_args( ... recipe=recipe, ... ) >>> for args in train_args: - ... # Recipe converts model-agnostic candidate to model-specific configs - ... model_config = recipe.build_model_config(args.candidate.target_params) + ... model_config = args.candidate.model_config # Model config is on the candidate ... batch_size, train_steps = recipe.compute_training_schedule(args.candidate) ... 
optimizer_config = recipe.build_optimizer_config(args.candidate) """ @@ -432,10 +429,11 @@ def generate_isoflop_train_args( for candidate in recipe.candidate_configs(budget, seq_len, steps_per_run, flop_tolerance): # Compute training schedule from recipe (for tags) batch_size, train_steps = recipe.compute_training_schedule(candidate, seq_len) + params = candidate.model_config.total_trainable_params(recipe.vocab_size) tags = ( f"FLOPs={budget:.1e}", - f"N={candidate.target_params:.1e}", + f"N={params:.1e}", f"B={batch_size}", f"steps={train_steps}", f"tokens={candidate.tokens:.1e}", @@ -641,9 +639,8 @@ def predict_optimal_config( if best.tokens < optimal_tokens: best = max(candidates, key=lambda c: c.tokens) - logger.info( - f"Selected config: N={best.target_params:.2e}, " f"tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})" - ) + params = best.model_config.total_trainable_params(recipe.vocab_size) + logger.info(f"Selected config: N={params:.2e}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})") return best @@ -676,9 +673,7 @@ def predict_optimal_configs_for_budgets( """ configs = [] for budget in target_budgets: - config = predict_optimal_config( - scaling_fits, budget, label, recipe, seq_len, steps_per_run, flop_tolerance - ) + config = predict_optimal_config(scaling_fits, budget, label, recipe, seq_len, steps_per_run, flop_tolerance) if config is None: raise RuntimeError( f"Failed to predict optimal config for budget {budget:.2e} FLOPs " diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index c779ee15a5..5b27f84579 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -84,7 +84,7 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: """Run one rung of the scaling ladder (one compute-optimal training run). 
The recipe handles all model-specific decisions (vocab_size is owned by the recipe): - - Model config is built via `recipe.build_model_config(target_params)` + - Model config is on the candidate: `candidate.model_config` - Training schedule is built via `recipe.compute_training_schedule(candidate)` - Optimizer config is built via `recipe.build_optimizer_config(candidate)` """ @@ -113,13 +113,14 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: f"Could not find optimal config for budget {config.target_budget:.2e} and label '{config.label}'" ) + params = candidate.model_config.total_trainable_params(config.recipe.vocab_size) logger.info( f"Training with optimal config for {config.target_budget:.2e} FLOPs:\n" - f" target_params={candidate.target_params:.2e}\n" + f" params={params:.2e}\n" f" tokens={candidate.tokens:.2e}" ) - model_cfg = config.recipe.build_model_config(candidate.target_params, config.seq_len) + model_cfg = candidate.model_config optimizer_cfg = config.recipe.build_optimizer_config(candidate, config.seq_len) tpu_type = pick_v5p_type(candidate, config.seq_len, config.recipe) @@ -135,7 +136,7 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: "optimal-training", f"FLOPs={config.target_budget:.1e}", f"label={config.label}", - f"N={candidate.target_params:.1e}", + f"N={params:.1e}", ], ), mp=jmp.get_policy("p=f32,c=bfloat16"), diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py index 044a708cf0..d12bfe8fcc 100644 --- a/lib/marin/src/marin/scaling_laws/tpu_utils.py +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -43,7 +43,7 @@ def pick_v5p_type( """Select the smallest TPU v5p slice that fits the model in float32. Args: - candidate: CandidateConfig with target_params and tokens. + candidate: CandidateConfig with model_config and tokens. seq_len: Sequence length. recipe: ScalingRecipe for memory estimation (includes vocab_size). diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index a43a9ea8a0..049439a14e 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -100,7 +100,7 @@ def test_candidate_configs_within_tolerance(): for candidate in recipe.candidate_configs(budget, seq_len, flop_tolerance=flop_tolerance): # Compute training schedule from recipe (vocab_size is owned by recipe) batch_size, train_steps = recipe.compute_training_schedule(candidate, seq_len) - model_config = recipe.build_model_config(candidate.target_params, seq_len) + model_config = candidate.model_config achieved = compute_training_flops( model_config, recipe.vocab_size, From 916ca6a046d9922d8ee004cc9c8a6811638311fd Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 12 Jan 2026 16:33:26 -0800 Subject: [PATCH 70/79] Requested Error --- experiments/isoflop_sweep.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 7bf109c86a..3296790e45 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -288,6 +288,11 @@ def _get_step_size(self, budget: float) -> int: def _build_model_config_from_hidden_size(self, hidden_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: """Build model config from hidden_size directly.""" + if hidden_size % self.hidden_head_ratio != 0: + raise ValueError( + f"hidden_size ({hidden_size}) must be divisible by hidden_head_ratio ({self.hidden_head_ratio}). " + f"Got remainder {hidden_size % self.hidden_head_ratio}." 
+ ) num_layers = self.compute_num_layers(hidden_size) intermediate_dim = hidden_size * self.mlp_ratio n_heads = max(1, hidden_size // self.hidden_head_ratio) From 75d8e1a171f94a18c9e43df94e5668b7f358b85d Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 12 Jan 2026 17:06:44 -0800 Subject: [PATCH 71/79] Refactor --- experiments/isoflop_sweep.py | 162 +++++++-------- lib/marin/src/marin/scaling_laws/__init__.py | 10 +- .../marin/scaling_laws/isoflop_analysis.py | 194 +++++++----------- .../src/marin/scaling_laws/scaling_ladder.py | 15 +- tests/test_scaling_laws.py | 44 ++-- 5 files changed, 179 insertions(+), 246 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 3296790e45..ac5873a549 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -30,7 +30,6 @@ from levanter.models.llama import LlamaConfig from levanter.models.qwen import Qwen3Config from levanter.optim.cautious import CautiousConfig -from levanter.optim.config import OptimizerConfig from experiments.evals.evals import default_eval from experiments.evals.task_configs import EvalTaskConfig @@ -49,11 +48,9 @@ IsoFlopRecord, ScalingRecipe, fit_scaling_laws, - generate_isoflop_train_args, + generate_training_configs, pick_v5p_type, round_flops_to_bucket, - solve_for_batch_size, - solve_for_train_steps, ) from marin.scaling_laws.eval_metrics_reader import read_raw_records from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT @@ -64,7 +61,6 @@ LEGACY_BUDGETS: tuple[float, ...] = (3e18, 9e18, 1.8e19, 3e19, 9e19, 1.8e20, 3e20) DEFAULT_SEQ_LEN: int = 4096 DEFAULT_STEPS_PER_RUN: int = 2**16 -DEFAULT_FLOP_TOLERANCE: float = 0.01 # ---------------- Levanter WandB Metric Keys ---------------- # These keys correspond to the metrics logged by Levanter's training callbacks. @@ -317,7 +313,7 @@ def estimate_memory_bytes( ) -> int: """Estimate float32 memory usage in bytes for training.""" model_config = candidate.model_config - batch_size, _ = self.compute_training_schedule(candidate, seq_len) + batch_size = candidate.batch_size param_count = model_config.total_trainable_params(self.vocab_size) param_bytes = param_count * optim_mult * dtype_size @@ -327,39 +323,58 @@ def estimate_memory_bytes( total_bytes = param_bytes + act_bytes return int(total_bytes * fudge_factor) - def compute_training_schedule(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> tuple[int, int]: - """Compute training schedule (batch_size, train_steps) for a candidate.""" - hidden_size = candidate.model_config.hidden_dim + def build_model_configs( + self, + budget: float, + seq_len: int = DEFAULT_SEQ_LEN, + ) -> Iterator[LlamaConfig]: + """Yield candidate model architectures for the given FLOP budget.""" + step_size = self._get_step_size(budget) + min_hidden = 2**self.min_hidden_pow + max_hidden = 2**self.max_hidden_pow + + for hidden_size in range(min_hidden, max_hidden + 1, step_size): + yield self._build_model_config_from_hidden_size(hidden_size, seq_len) + + def build_candidate_config( + self, + model_config: LlamaConfig, + tokens: float, + flops_budget: float, + seq_len: int = DEFAULT_SEQ_LEN, + ) -> CandidateConfig | None: + """Build complete training config for a model and token count. - # Start with batch_size that gives us ~DEFAULT_STEPS_PER_RUN steps for the tokens + Returns None if the configuration is invalid (e.g., batch_size < minimum + after learning rate constraints are applied). 
+ """ + hidden_size = model_config.hidden_dim + + # Start with batch_size that gives us ~DEFAULT_STEPS_PER_RUN steps target_steps = DEFAULT_STEPS_PER_RUN - batch_exact = candidate.tokens / (target_steps * seq_len) + batch_exact = tokens / (target_steps * seq_len) batch_size = _round_to_power_of_two(batch_exact) # Adjust batch_size to respect learning rate constraints lr = self._compute_learning_rate(batch_size, hidden_size) - while lr > self.max_learning_rate and batch_size >= self.min_batch_size * 2: + while lr > self.max_learning_rate: batch_size //= 2 lr = self._compute_learning_rate(batch_size, hidden_size) - # Ensure minimum batch size + # Return None if batch_size is below minimum if batch_size < self.min_batch_size: - batch_size = self.min_batch_size + return None # Compute train_steps to achieve target tokens - train_steps = round(candidate.tokens / (batch_size * seq_len)) + train_steps = round(tokens / (batch_size * seq_len)) - return (batch_size, train_steps) + # Compute actual tokens after rounding + actual_tokens = batch_size * train_steps * seq_len - def build_optimizer_config(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> OptimizerConfig: - """Build optimizer config for a candidate.""" - batch_size, _ = self.compute_training_schedule(candidate, seq_len) - hidden_size = candidate.model_config.hidden_dim - learning_rate = self._compute_learning_rate(batch_size, hidden_size) + # Build optimizer config beta2 = self._compute_beta2(batch_size) - - return CautiousConfig( - learning_rate=learning_rate, + optimizer_config = CautiousConfig( + learning_rate=lr, weight_decay=self.weight_decay, min_lr_ratio=self.min_lr_ratio, warmup=self.warmup, @@ -372,54 +387,14 @@ def build_optimizer_config(self, candidate: CandidateConfig, seq_len: int = DEFA decay=self.decay, ) - def candidate_configs( - self, - budget: float, - seq_len: int = DEFAULT_SEQ_LEN, - steps_per_run: int = DEFAULT_STEPS_PER_RUN, - flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, - ) -> Iterator[CandidateConfig]: - """Yield candidate configurations within the FLOP budget. - - Iterates over feasible model architectures, computes tokens to hit the - FLOP budget, and yields CandidateConfigs with the model_config directly. 
- """ - step_size = self._get_step_size(budget) - min_hidden = 2**self.min_hidden_pow - max_hidden = 2**self.max_hidden_pow - - for hidden_size in range(min_hidden, max_hidden + 1, step_size): - model_config = self._build_model_config_from_hidden_size(hidden_size, seq_len) - - # Compute batch_size to hit FLOP budget - batch_exact = solve_for_batch_size(model_config, self.vocab_size, budget, steps_per_run, seq_len) - batch_size = _round_to_power_of_two(batch_exact) - - # Adjust batch_size to respect learning rate constraints - lr = self._compute_learning_rate(batch_size, hidden_size) - while lr > self.max_learning_rate: - batch_size //= 2 - lr = self._compute_learning_rate(batch_size, hidden_size) - - if batch_size < self.min_batch_size: - continue - - train_steps = round(solve_for_train_steps(model_config, self.vocab_size, budget, batch_size, seq_len)) - - # Validate achieved FLOPs are within tolerance - achieved_flops = ( - 3 * model_config.flops_per_token(self.vocab_size, seq_len) * batch_size * train_steps * seq_len - ) - if abs(achieved_flops - budget) / budget > flop_tolerance: - continue - - tokens = batch_size * train_steps * seq_len - - yield CandidateConfig( - model_config=model_config, - tokens=tokens, - flops_budget=budget, - ) + return CandidateConfig( + model_config=model_config, + optimizer_config=optimizer_config, + batch_size=batch_size, + train_steps=train_steps, + tokens=actual_tokens, + flops_budget=flops_budget, + ) MARIN_2025_RECIPE = Marin2025Recipe() @@ -551,7 +526,7 @@ def create_isoflop_sweep_steps( """Create ExecutorSteps for an ISOFlop sweep. This function creates ExecutorSteps directly in experiment code, using - `generate_isoflop_train_args()` from the library to compute configs. + `generate_training_configs()` from the library to compute configs. Args: tokenized: Tokenized dataset to train on. @@ -566,14 +541,14 @@ def create_isoflop_sweep_steps( - steps: Training and evaluation ExecutorSteps for the sweep. - candidates: CandidateConfig for each training run with full config details. """ - # Library provides the training arguments (model configs, optimizer configs, etc.) 
- # vocab_size is owned by the recipe - train_args_list = generate_isoflop_train_args( + # Generate complete training configs from the library + candidates = generate_training_configs( budgets=budgets, recipe=recipe, + seq_len=seq_len, ) - # Base config for training runs (values overridden per-candidate via optimizer_config) + # Base config for training runs (values overridden per-candidate) base_train_config = SimpleTrainConfig( resources=ResourceConfig.with_tpu("v5p-8"), train_batch_size=1, @@ -583,37 +558,39 @@ def create_isoflop_sweep_steps( train_steps: list[ExecutorStep] = [] eval_steps: list[ExecutorStep] = [] - candidates: list[CandidateConfig] = [] # Create ExecutorSteps for each candidate configuration - for args in train_args_list: - candidate = args.candidate - - # Model config is on the candidate; build optimizer config using the recipe + for candidate in candidates: model_config = candidate.model_config - optimizer_config = recipe.build_optimizer_config(candidate, seq_len) tpu_type = pick_v5p_type(candidate, seq_len, recipe) - # Compute training schedule from recipe - batch_size, num_steps = recipe.compute_training_schedule(candidate, seq_len) - # Use local naming with architecture details for backward compatibility run_name = _format_run_name( candidate.flops_budget, model_config.hidden_dim, model_config.num_layers, - batch_size, + candidate.batch_size, experiment_name, ) output_path = f"checkpoints/isoflop/{run_name}" + # Build tags for tracking + params = model_config.total_trainable_params(recipe.vocab_size) + tags = ( + f"FLOPs={candidate.flops_budget:.1e}", + f"N={params:.1e}", + f"B={candidate.batch_size}", + f"steps={candidate.train_steps}", + f"tokens={candidate.tokens:.1e}", + ) + train_cfg = replace( base_train_config, - train_batch_size=batch_size, - learning_rate=optimizer_config.learning_rate, - num_train_steps=num_steps, + train_batch_size=candidate.batch_size, + learning_rate=candidate.optimizer_config.learning_rate, + num_train_steps=candidate.train_steps, resources=ResourceConfig.with_tpu(tpu_type), - optimizer_config=optimizer_config, + optimizer_config=candidate.optimizer_config, ) # Create training step @@ -623,13 +600,12 @@ def create_isoflop_sweep_steps( model_config=model_config, train_config=train_cfg, eval_harness_tasks=[], - tags=args.tags, + tags=tags, ) # Pin to static output path for checkpoint reuse train_step = train_step.with_output_path(output_path) train_steps.append(train_step) - candidates.append(candidate) # Create evaluation step if eval tasks specified if eval_tasks: diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 1d9b09d112..e8ce6c431e 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -15,13 +15,10 @@ from marin.scaling_laws.isoflop_analysis import ( DEFAULT_BUDGETS, DEFAULT_EVAL_METRIC_KEY, - DEFAULT_FLOP_TOLERANCE, DEFAULT_SEQ_LEN, - DEFAULT_STEPS_PER_RUN, CandidateConfig, FitScalingLawsResult, IsoFlopRecord, - IsoFlopTrainArgs, MinimaRecord, ModelConfiguration, QuadraticFitCoeffs, @@ -29,7 +26,7 @@ ScalingRecipe, compute_training_flops, fit_scaling_laws, - generate_isoflop_train_args, + generate_training_configs, predict_optimal_config, predict_optimal_configs_for_budgets, round_flops_to_bucket, @@ -54,14 +51,11 @@ # Constants "DEFAULT_BUDGETS", "DEFAULT_EVAL_METRIC_KEY", - "DEFAULT_FLOP_TOLERANCE", "DEFAULT_SEQ_LEN", - "DEFAULT_STEPS_PER_RUN", # Data classes and Protocols "CandidateConfig", 
"FitScalingLawsResult", "IsoFlopRecord", - "IsoFlopTrainArgs", "MinimaRecord", "ModelConfiguration", "QuadraticFitCoeffs", @@ -73,7 +67,7 @@ "create_isoflop_plot", "create_scaling_plot", "fit_scaling_laws", - "generate_isoflop_train_args", + "generate_training_configs", "pick_v5p_type", "predict_optimal_config", "predict_optimal_configs_for_budgets", diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 5249015f07..46df3f89e4 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -21,12 +21,12 @@ Key types: - IsoFlopRecord: The contract for a single training run's metrics - FitScalingLawsResult: Output from fit_scaling_laws() -- CandidateConfig: Model-agnostic compute allocation from scaling law analysis +- CandidateConfig: Complete training configuration (model, optimizer, schedule) Key functions: - fit_scaling_laws(records): Fit scaling laws from typed records - predict_optimal_config(): Predict optimal training config for a target budget -- generate_isoflop_train_args(): Generate training args for an isoflop sweep +- generate_training_configs(): Generate training configs for an isoflop sweep """ import logging @@ -117,8 +117,6 @@ class IsoFlopRecord: # ---------------- IsoFLOP Sweep Defaults ---------------- DEFAULT_SEQ_LEN = SEQ_LEN -DEFAULT_STEPS_PER_RUN = 2**16 # Reference step count for hyperparameter tuning -DEFAULT_FLOP_TOLERANCE = 0.01 # Relative error tolerance for FLOP budget # ---------------- Model Configuration Protocol ---------------- @@ -146,21 +144,29 @@ def total_trainable_params(self, vocab_size: int) -> int: @dataclass class CandidateConfig: - """Compute allocation from scaling law analysis. - - Contains the model configuration and training parameters: - - model_config: The actual model architecture (satisfies ModelConfiguration protocol) - - tokens: How many tokens to train on + """Complete training configuration for a scaling law candidate. + + Contains everything needed to run a training job: + - model_config: The model architecture + - optimizer_config: Optimizer with learning rate, beta2, etc. + - batch_size: Training batch size + - train_steps: Number of training steps + - tokens: Total tokens to train on (batch_size * train_steps * seq_len) - flops_budget: The compute budget this config was generated for - - Parameter count is derived from model_config.total_trainable_params(vocab_size). - Training schedule (batch_size, train_steps) is computed by the ScalingRecipe - at training time via compute_training_schedule(). """ model_config: ModelConfiguration """Model configuration for this candidate.""" + optimizer_config: OptimizerConfig + """Optimizer configuration with learning rate, weight decay, etc.""" + + batch_size: int + """Training batch size.""" + + train_steps: int + """Number of training steps.""" + tokens: float """Total tokens to train on.""" @@ -172,7 +178,7 @@ class ScalingRecipe(Protocol): """Protocol defining the interface for scaling law recipes. Concrete implementations (e.g., Marin2025Recipe) should implement these - model-specific methods. Orchestration logic (generating training args, + model-specific methods. Orchestration logic (generating training configs, predicting optimal configs) is handled by library functions that use these core methods. 
@@ -188,66 +194,37 @@ class ScalingRecipe(Protocol): """Vocabulary size for the tokenizer used with this recipe.""" def estimate_memory_bytes(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> int: - """Estimate memory usage in bytes for training a candidate configuration. - - The implementation can access candidate.model_config, candidate.tokens, and - candidate.flops_budget to compute memory requirements. This allows the recipe - to compute the actual batch_size (from tokens) when estimating memory. - """ + """Estimate memory usage in bytes for training a candidate configuration.""" ... - def build_optimizer_config(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> OptimizerConfig: - """Build optimizer config for a candidate.""" - ... - - def candidate_configs( + def build_model_configs( self, budget: float, seq_len: int = DEFAULT_SEQ_LEN, - steps_per_run: int = DEFAULT_STEPS_PER_RUN, - flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, - ) -> Iterator[CandidateConfig]: - """Yield candidate configurations within the FLOP budget. - - Each candidate includes the model_config directly. A typical implementation - will iterate over feasible model architectures, compute the tokens needed - to hit the FLOP budget, and yield configs where the relative FLOP error - is within tolerance. - - The implementation should handle model-specific constraints like: - - Hidden size increments (e.g., multiples of 64 or 128) - - Memory constraints affecting maximum batch size - """ - ... + ) -> Iterator[ModelConfiguration]: + """Yield candidate model architectures for the given FLOP budget. - def compute_training_schedule(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> tuple[int, int]: - """Compute training schedule (batch_size, train_steps) for a candidate.""" + A typical implementation will iterate over hidden sizes (the primary + architectural knob) and yield model configs for each feasible size. + """ ... + def build_candidate_config( + self, + model_config: ModelConfiguration, + tokens: float, + flops_budget: float, + seq_len: int = DEFAULT_SEQ_LEN, + ) -> CandidateConfig | None: + """Build complete training config for a model and token count. -@dataclass -class IsoFlopTrainArgs: - """Arguments needed to set up an isoflop training run. - - This dataclass contains the parameters needed for training. - The ScalingRecipe is responsible for computing training schedules - and optimizer hyperparameters from the candidate. - - Naming (run_name, output_path) is intentionally not included here - that's - the responsibility of experiment code which may have its own conventions. - - Example: - >>> args = generate_isoflop_train_args(budgets, recipe)[0] - >>> model_config = args.candidate.model_config # Model config is on the candidate - >>> batch_size, train_steps = recipe.compute_training_schedule(args.candidate) - >>> optimizer_config = recipe.build_optimizer_config(args.candidate) - """ - - candidate: CandidateConfig - """Compute allocation (model_config, tokens, flops_budget).""" + Solves for batch_size, computes optimizer hyperparameters (learning rate, + beta2, etc.), and returns a complete CandidateConfig. - tags: tuple[str, ...] - """Tags for tracking/filtering runs.""" + Returns None if the configuration is invalid (e.g., batch_size < minimum + after learning rate constraints are applied). + """ + ... 
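A sketch (again illustrative only) of the smallest recipe that fits the new protocol split, reusing the ToyModelConfig and CandidateConfig from the sketch above. The fixed batch size and the hidden-size grid are placeholders; the real constraint handling (learning-rate caps, power-of-two rounding, memory checks) lives in Marin2025Recipe, and estimate_memory_bytes is omitted here:

from collections.abc import Iterator


class TinyRecipe:
    vocab_size: int = 128_256

    def build_model_configs(self, budget: float, seq_len: int = 4096) -> Iterator[ToyModelConfig]:
        # A real recipe scales this grid with the budget; these sizes are arbitrary.
        for hidden in (512, 768, 1024):
            yield ToyModelConfig(hidden_dim=hidden, num_layers=max(2, hidden // 96))

    def build_candidate_config(self, model_config, tokens, flops_budget, seq_len=4096):
        batch_size = 64  # placeholder; the real recipe solves for this and enforces lr limits
        if tokens < batch_size * seq_len:
            return None  # infeasible schedule, mirroring the protocol contract
        train_steps = round(tokens / (batch_size * seq_len))
        return CandidateConfig(
            model_config=model_config,
            optimizer_config=None,  # stand-in; see the CautiousConfig built in the sweep recipe
            batch_size=batch_size,
            train_steps=train_steps,
            tokens=batch_size * train_steps * seq_len,
            flops_budget=flops_budget,
        )

generate_training_configs() only needs vocab_size plus these two methods, which is why a duck-typed recipe like the one above is enough for the library-side loop.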
# ---------------- Typed Records ---------------- @@ -384,67 +361,51 @@ def solve_for_train_steps( return target_flops / (3 * flops_per_token * batch_size * seq_len) -# ---------------- Training Args Generation ---------------- +# ---------------- Training Config Generation ---------------- -def generate_isoflop_train_args( +def generate_training_configs( budgets: Sequence[float], recipe: ScalingRecipe, seq_len: int = DEFAULT_SEQ_LEN, - steps_per_run: int = DEFAULT_STEPS_PER_RUN, - flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, -) -> list[IsoFlopTrainArgs]: - """Generate training arguments for each candidate in an isoflop sweep. +) -> list[CandidateConfig]: + """Generate training configurations for an isoflop sweep. - Returns IsoFlopTrainArgs containing CandidateConfig objects with model configs. - Use recipe.build_optimizer_config() to get optimizer configs. - Naming (run_name, output_path) is left to the caller. + For each FLOP budget: + 1. Gets candidate model architectures from the recipe + 2. Computes tokens needed to achieve the budget: tokens = budget / (3 * flops_per_token) + 3. Builds complete training configs via recipe.build_candidate_config() + 4. Filters out invalid configs (where build_candidate_config returns None) Args: budgets: Sequence of FLOP budgets to generate configs for. - recipe: ScalingRecipe with architecture/hyperparameter settings (includes vocab_size). + recipe: ScalingRecipe with architecture/hyperparameter settings. seq_len: Sequence length for training. - steps_per_run: Reference step count for FLOP budget calculation. - flop_tolerance: Tolerance for matching FLOP budget. Returns: - List of IsoFlopTrainArgs, one per candidate config across all budgets. + List of CandidateConfig, each containing model_config, optimizer_config, + batch_size, train_steps, tokens, and flops_budget. Example: - >>> from marin.scaling_laws import generate_isoflop_train_args, DEFAULT_BUDGETS - >>> # Use a concrete recipe implementation (e.g., from experiments/isoflop_sweep.py) - >>> # recipe = Marin2025Recipe() # vocab_size is a property of the recipe - >>> train_args = generate_isoflop_train_args( - ... budgets=DEFAULT_BUDGETS, - ... recipe=recipe, - ... ) - >>> for args in train_args: - ... model_config = args.candidate.model_config # Model config is on the candidate - ... batch_size, train_steps = recipe.compute_training_schedule(args.candidate) - ... optimizer_config = recipe.build_optimizer_config(args.candidate) + >>> from marin.scaling_laws import generate_training_configs, DEFAULT_BUDGETS + >>> configs = generate_training_configs(budgets=DEFAULT_BUDGETS, recipe=recipe) + >>> for cfg in configs: + ... print(f"N={cfg.model_config.total_trainable_params(recipe.vocab_size):.1e}") + ... print(f"batch_size={cfg.batch_size}, steps={cfg.train_steps}") + ... 
print(f"lr={cfg.optimizer_config.learning_rate}") """ - results: list[IsoFlopTrainArgs] = [] + results: list[CandidateConfig] = [] for budget in budgets: - for candidate in recipe.candidate_configs(budget, seq_len, steps_per_run, flop_tolerance): - # Compute training schedule from recipe (for tags) - batch_size, train_steps = recipe.compute_training_schedule(candidate, seq_len) - params = candidate.model_config.total_trainable_params(recipe.vocab_size) - - tags = ( - f"FLOPs={budget:.1e}", - f"N={params:.1e}", - f"B={batch_size}", - f"steps={train_steps}", - f"tokens={candidate.tokens:.1e}", - ) + for model_config in recipe.build_model_configs(budget, seq_len): + # Compute tokens directly from budget + flops_per_token = model_config.flops_per_token(recipe.vocab_size, seq_len) + tokens = budget / (3 * flops_per_token) - results.append( - IsoFlopTrainArgs( - candidate=candidate, - tags=tags, - ) - ) + # Build complete training config (returns None if invalid) + candidate = recipe.build_candidate_config(model_config, tokens, budget, seq_len) + if candidate is not None: + results.append(candidate) return results @@ -594,8 +555,6 @@ def predict_optimal_config( label: str, recipe: ScalingRecipe, seq_len: int = DEFAULT_SEQ_LEN, - steps_per_run: int = DEFAULT_STEPS_PER_RUN, - flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> CandidateConfig | None: """Predict optimal training config for a target compute budget using fitted scaling laws. @@ -612,8 +571,6 @@ def predict_optimal_config( label: Dataset/experiment label to use for scaling fit. recipe: ScalingRecipe with architecture/hyperparameter settings (includes vocab_size). seq_len: Sequence length for training. - steps_per_run: Reference step count for FLOP budget calculation. - flop_tolerance: Tolerance for matching FLOP budget. Returns: CandidateConfig for the predicted optimal, or None if label not in fits @@ -628,7 +585,14 @@ def predict_optimal_config( logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") - candidates = list(recipe.candidate_configs(target_flops, seq_len, steps_per_run, flop_tolerance)) + # Build candidates using the new API + candidates: list[CandidateConfig] = [] + for model_config in recipe.build_model_configs(target_flops, seq_len): + flops_per_token = model_config.flops_per_token(recipe.vocab_size, seq_len) + tokens = target_flops / (3 * flops_per_token) + candidate = recipe.build_candidate_config(model_config, tokens, target_flops, seq_len) + if candidate is not None: + candidates.append(candidate) if not candidates: logger.warning(f"No valid candidates found for budget {target_flops:.2e}") @@ -651,8 +615,6 @@ def predict_optimal_configs_for_budgets( label: str, recipe: ScalingRecipe, seq_len: int = DEFAULT_SEQ_LEN, - steps_per_run: int = DEFAULT_STEPS_PER_RUN, - flop_tolerance: float = DEFAULT_FLOP_TOLERANCE, ) -> list[CandidateConfig]: """Predict optimal configs for multiple target compute budgets. @@ -662,8 +624,6 @@ def predict_optimal_configs_for_budgets( label: Dataset/experiment label to use for scaling fit. recipe: ScalingRecipe with architecture/hyperparameter settings (includes vocab_size). seq_len: Sequence length for training. - steps_per_run: Reference step count for FLOP budget calculation. - flop_tolerance: Tolerance for matching FLOP budget. Returns: List of CandidateConfig for each budget. 
@@ -673,7 +633,7 @@ def predict_optimal_configs_for_budgets( """ configs = [] for budget in target_budgets: - config = predict_optimal_config(scaling_fits, budget, label, recipe, seq_len, steps_per_run, flop_tolerance) + config = predict_optimal_config(scaling_fits, budget, label, recipe, seq_len) if config is None: raise RuntimeError( f"Failed to predict optimal config for budget {budget:.2e} FLOPs " diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py index 5b27f84579..1b2664f964 100644 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ b/lib/marin/src/marin/scaling_laws/scaling_ladder.py @@ -83,10 +83,10 @@ class ScalingLadderRungConfig: def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: """Run one rung of the scaling ladder (one compute-optimal training run). - The recipe handles all model-specific decisions (vocab_size is owned by the recipe): - - Model config is on the candidate: `candidate.model_config` - - Training schedule is built via `recipe.compute_training_schedule(candidate)` - - Optimizer config is built via `recipe.build_optimizer_config(candidate)` + The candidate contains all training configuration: + - Model config: `candidate.model_config` + - Optimizer config: `candidate.optimizer_config` + - Training schedule: `candidate.batch_size`, `candidate.train_steps` """ result_path = os.path.join(config.analysis_output_path, "isoflop_analysis_result.json") fs, _, _ = fsspec.get_fs_token_paths(result_path) @@ -121,12 +121,11 @@ def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: ) model_cfg = candidate.model_config - optimizer_cfg = config.recipe.build_optimizer_config(candidate, config.seq_len) + optimizer_cfg = candidate.optimizer_config + batch_size = candidate.batch_size + train_steps = candidate.train_steps tpu_type = pick_v5p_type(candidate, config.seq_len, config.recipe) - # Compute training schedule - recipe-specific - batch_size, train_steps = config.recipe.compute_training_schedule(candidate, config.seq_len) - train_config = TrainLmConfig( data=config.tokenized, trainer=TrainerConfig( diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 049439a14e..de3be0970f 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -25,10 +25,10 @@ from marin.scaling_laws.isoflop_analysis import ( DEFAULT_SEQ_LEN, - IsoFlopTrainArgs, + CandidateConfig, compute_training_flops, fit_scaling_laws, - generate_isoflop_train_args, + generate_training_configs, robust_quad_logx, solve_for_batch_size, solve_for_train_steps, @@ -97,15 +97,20 @@ def test_candidate_configs_within_tolerance(): flop_tolerance = 0.01 seq_len = DEFAULT_SEQ_LEN - for candidate in recipe.candidate_configs(budget, seq_len, flop_tolerance=flop_tolerance): - # Compute training schedule from recipe (vocab_size is owned by recipe) - batch_size, train_steps = recipe.compute_training_schedule(candidate, seq_len) - model_config = candidate.model_config + # Generate candidates using the new API + for model_config in recipe.build_model_configs(budget, seq_len): + flops_per_token = model_config.flops_per_token(recipe.vocab_size, seq_len) + tokens = budget / (3 * flops_per_token) + candidate = recipe.build_candidate_config(model_config, tokens, budget, seq_len) + + if candidate is None: + continue + achieved = compute_training_flops( - model_config, + candidate.model_config, recipe.vocab_size, - batch_size, - train_steps, + candidate.batch_size, + candidate.train_steps, seq_len, ) 
relative_error = abs(achieved - budget) / budget @@ -131,7 +136,7 @@ def test_robust_quad_logx_fits_quadratic(): # --- Snapshot test for config generation --- -# Snapshot of expected output for generate_isoflop_train_args with budget=3e18 training FLOPs. +# Snapshot of expected output for generate_training_configs with budget=3e18 training FLOPs. EXPECTED_ISOFLOP_CONFIGS_3E18 = [ {"batch_size": 32, "train_steps": 32844, "flops_budget": 3e18}, {"batch_size": 16, "train_steps": 46274, "flops_budget": 3e18}, @@ -141,26 +146,25 @@ def test_robust_quad_logx_fits_quadratic(): ] -def test_generate_isoflop_train_args_snapshot(): - """Snapshot test: verify generate_isoflop_train_args produces expected configs. +def test_generate_training_configs_snapshot(): + """Snapshot test: verify generate_training_configs produces expected configs. This ensures reproducibility of the config generation algorithm. """ recipe = Marin2025Recipe() - result = generate_isoflop_train_args( + result = generate_training_configs( budgets=(3e18,), recipe=recipe, ) assert len(result) == len(EXPECTED_ISOFLOP_CONFIGS_3E18) - for i, (args, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_3E18, strict=True)): - assert isinstance(args, IsoFlopTrainArgs) - # batch_size and train_steps are computed from recipe (vocab_size is owned by recipe) - batch_size, train_steps = recipe.compute_training_schedule(args.candidate) - assert batch_size == expected["batch_size"], f"Config {i}: batch_size mismatch" - assert train_steps == expected["train_steps"], f"Config {i}: train_steps mismatch" - assert args.candidate.flops_budget == expected["flops_budget"], f"Config {i}: flops_budget mismatch" + for i, (candidate, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_3E18, strict=True)): + assert isinstance(candidate, CandidateConfig) + # batch_size and train_steps are now directly on the candidate + assert candidate.batch_size == expected["batch_size"], f"Config {i}: batch_size mismatch" + assert candidate.train_steps == expected["train_steps"], f"Config {i}: train_steps mismatch" + assert candidate.flops_budget == expected["flops_budget"], f"Config {i}: flops_budget mismatch" # --- End-to-end integration test --- From 101cd87dca01afe60cff51334764f03e054835c8 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 12 Jan 2026 17:20:26 -0800 Subject: [PATCH 72/79] Delete More --- .../exp2166_scaling_ladder_analysis.py | 95 ++++++++-- lib/marin/src/marin/scaling_laws/__init__.py | 14 -- .../marin/scaling_laws/isoflop_analysis.py | 121 +------------ .../src/marin/scaling_laws/scaling_ladder.py | 170 ------------------ tests/test_scaling_laws.py | 47 +---- 5 files changed, 87 insertions(+), 360 deletions(-) delete mode 100644 lib/marin/src/marin/scaling_laws/scaling_ladder.py diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index ce224cdbf8..c34aec7f1a 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -23,7 +23,14 @@ 3. 
Optionally trains compute-optimal models at larger target budgets """ -from experiments.defaults import default_validation_sets +import json +import logging +import os + +import fsspec +from fray.cluster import ResourceConfig + +from experiments.defaults import default_train, default_validation_sets from experiments.isoflop_sweep import ( IsoFlopAnalysisConfig, MARIN_2025_RECIPE, @@ -31,12 +38,13 @@ nemotron_mix, run_isoflop_analysis_step, ) +from experiments.simple_train_config import SimpleTrainConfig from marin.execution.executor import ExecutorStep, executor_main, output_path_of, this_output_path from marin.processing.tokenize import add_validation_sets_to_mixture -from marin.scaling_laws import ( - ScalingLadderRungConfig, - run_scaling_ladder_rung, -) +from marin.scaling_laws import ScalingFit, predict_optimal_config +from marin.scaling_laws.tpu_utils import pick_v5p_type + +logger = logging.getLogger(__name__) # Get training steps from the isoflop sweep nemotron_training, _ = MARIN_SCALING_SUITES["nemotron"] @@ -46,10 +54,78 @@ EXPERIMENT_NAME = "exp2166-scaling-ladder-nemotron-validation" LABEL = "nemo-wider-depth-adapt" TOKENIZER = "stanford-crfm/marin-tokenizer" +SEQ_LEN = 4096 # Add validation sets to the training mixture nemotron_mix_with_validation = add_validation_sets_to_mixture(nemotron_mix, default_validation_sets(tokenizer=TOKENIZER)) + +def run_optimal_training( + analysis_output_path: str, + target_budget: float, + label: str, +) -> ExecutorStep: + """Create an ExecutorStep for compute-optimal training at the given budget. + + Loads scaling fits from the analysis output, predicts the optimal config, + and returns an ExecutorStep using default_train. + """ + result_path = os.path.join(analysis_output_path, "isoflop_analysis_result.json") + fs, _, _ = fsspec.get_fs_token_paths(result_path) + + with fs.open(result_path, "r") as f: + analysis_result = json.load(f) + + scaling_fits: dict[str, ScalingFit] = {} + for key, value in analysis_result["scaling_fits"].items(): + if len(value) != 2: + raise ValueError(f"Expected 2 scaling fit values for '{key}', got {len(value)}") + scaling_fits[key] = ScalingFit(float(value[0]), float(value[1])) + + candidate = predict_optimal_config( + scaling_fits=scaling_fits, + target_flops=target_budget, + label=label, + recipe=MARIN_2025_RECIPE, + seq_len=SEQ_LEN, + ) + + if candidate is None: + raise RuntimeError(f"Could not find optimal config for budget {target_budget:.2e} and label '{label}'") + + params = candidate.model_config.total_trainable_params(MARIN_2025_RECIPE.vocab_size) + logger.info( + f"Training with optimal config for {target_budget:.2e} FLOPs:\n" + f" params={params:.2e}\n" + f" tokens={candidate.tokens:.2e}" + ) + + tpu_type = pick_v5p_type(candidate, SEQ_LEN, MARIN_2025_RECIPE) + + train_config = SimpleTrainConfig( + resources=ResourceConfig.with_tpu(tpu_type), + train_batch_size=candidate.batch_size, + num_train_steps=candidate.train_steps, + learning_rate=candidate.optimizer_config.learning_rate, + optimizer_config=candidate.optimizer_config, + train_seq_len=SEQ_LEN, + ) + + return default_train( + name=f"{EXPERIMENT_NAME}-optimal-{target_budget:.0e}", + tokenized=nemotron_mix_with_validation, + model_config=candidate.model_config, + train_config=train_config, + tags=[ + "optimal-training", + f"FLOPs={target_budget:.1e}", + f"label={label}", + f"N={params:.1e}", + ], + use_default_validation=False, # Already added above + ) + + # --- Step 1: IsoFLOP Analysis --- # Creates scaling law fits from the training runs 
analysis_step = ExecutorStep( @@ -68,15 +144,12 @@ for budget in TARGET_BUDGETS: step = ExecutorStep( name=f"{EXPERIMENT_NAME}-optimal-{budget:.0e}", - fn=run_scaling_ladder_rung, - config=ScalingLadderRungConfig( + fn=lambda b=budget: run_optimal_training( analysis_output_path=output_path_of(analysis_step), - target_budget=budget, + target_budget=b, label=LABEL, - tokenized=nemotron_mix_with_validation, - output_path=this_output_path(), - recipe=MARIN_2025_RECIPE, ), + config=None, ) optimal_runs.append(step) diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index e8ce6c431e..618ef8e833 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -24,22 +24,14 @@ QuadraticFitCoeffs, ScalingFit, ScalingRecipe, - compute_training_flops, fit_scaling_laws, generate_training_configs, predict_optimal_config, - predict_optimal_configs_for_budgets, round_flops_to_bucket, - solve_for_batch_size, - solve_for_train_steps, ) from marin.scaling_laws.tpu_utils import ( pick_v5p_type, ) -from marin.scaling_laws.scaling_ladder import ( - ScalingLadderRungConfig, - run_scaling_ladder_rung, -) from marin.scaling_laws.scaling_plots import ( create_isoflop_plot, create_scaling_plot, @@ -60,21 +52,15 @@ "ModelConfiguration", "QuadraticFitCoeffs", "ScalingFit", - "ScalingLadderRungConfig", "ScalingRecipe", # Functions - "compute_training_flops", "create_isoflop_plot", "create_scaling_plot", "fit_scaling_laws", "generate_training_configs", "pick_v5p_type", "predict_optimal_config", - "predict_optimal_configs_for_budgets", "round_flops_to_bucket", - "run_scaling_ladder_rung", "save_plots", - "solve_for_batch_size", - "solve_for_train_steps", "upload_plots_to_wandb", ] diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 46df3f89e4..0ddb3bf66c 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -48,7 +48,7 @@ # common loss metric that measures model perplexity on the C4 English dataset. # See: https://arxiv.org/abs/2312.10523 DEFAULT_EVAL_METRIC_KEY = "eval/paloma/c4_en/bpb" -SEQ_LEN = 4096 +DEFAULT_SEQ_LEN = 4096 # ---------------- IsoFLOP Sweep Constants ---------------- # Budgets in training FLOPs (includes 3x multiplier for forward + backward pass). @@ -115,10 +115,6 @@ class IsoFlopRecord: """Experiment label for grouping (e.g., 'nemo', 'dclm').""" -# ---------------- IsoFLOP Sweep Defaults ---------------- -DEFAULT_SEQ_LEN = SEQ_LEN - - # ---------------- Model Configuration Protocol ---------------- @@ -281,86 +277,6 @@ def round_flops_to_bucket(flops: float, base: float = 1.1) -> float: return base ** round(k) -def compute_training_flops( - model_config: ModelConfiguration, - vocab_size: int, - batch_size: int, - train_steps: int, - seq_len: int, -) -> float: - """Compute total training FLOPs using the model config's own method. - - This returns training FLOPs which includes forward pass (1x) + backward pass (2x) = 3x. - This matches the FLOP accounting in Levanter's log_performance_stats callback - (see train_lm.py) and standard ML conventions (e.g., Chinchilla paper). - - Args: - model_config: Model config with flops_per_token method. - vocab_size: Vocabulary size. - batch_size: Training batch size. - train_steps: Number of training steps. - seq_len: Sequence length. 
- - Returns: - Total training FLOPs (including 3x multiplier for forward + backward pass). - """ - flops_per_token = model_config.flops_per_token(vocab_size, seq_len) - # Multiply by 3 for training: forward (1x) + backward (2x) - return 3 * flops_per_token * batch_size * train_steps * seq_len - - -def solve_for_batch_size( - model_config: ModelConfiguration, - vocab_size: int, - target_flops: float, - train_steps: int, - seq_len: int, -) -> float: - """Solve for batch size needed to hit a target FLOP budget. - - Given: total_flops = 3 * flops_per_token * batch * steps * seq_len - Solve: batch = total_flops / (3 * flops_per_token * steps * seq_len) - - Args: - model_config: Model config with flops_per_token method. - vocab_size: Vocabulary size. - target_flops: Target total training FLOPs. - train_steps: Number of training steps. - seq_len: Sequence length. - - Returns: - Exact batch size (float) - caller decides how to round. - """ - flops_per_token = model_config.flops_per_token(vocab_size, seq_len) - return target_flops / (3 * flops_per_token * train_steps * seq_len) - - -def solve_for_train_steps( - model_config: ModelConfiguration, - vocab_size: int, - target_flops: float, - batch_size: int, - seq_len: int, -) -> float: - """Solve for training steps needed to hit a target FLOP budget. - - Given: total_flops = 3 * flops_per_token * batch * steps * seq_len - Solve: steps = total_flops / (3 * flops_per_token * batch * seq_len) - - Args: - model_config: Model config with flops_per_token method. - vocab_size: Vocabulary size. - target_flops: Target total training FLOPs. - batch_size: Training batch size. - seq_len: Sequence length. - - Returns: - Exact training steps (float) - caller decides how to round. - """ - flops_per_token = model_config.flops_per_token(vocab_size, seq_len) - return target_flops / (3 * flops_per_token * batch_size * seq_len) - - # ---------------- Training Config Generation ---------------- @@ -607,38 +523,3 @@ def predict_optimal_config( logger.info(f"Selected config: N={params:.2e}, tokens={best.tokens:.2e} (optimal: {optimal_tokens:.2e})") return best - - -def predict_optimal_configs_for_budgets( - scaling_fits: dict[str, ScalingFit], - target_budgets: list[float], - label: str, - recipe: ScalingRecipe, - seq_len: int = DEFAULT_SEQ_LEN, -) -> list[CandidateConfig]: - """Predict optimal configs for multiple target compute budgets. - - Args: - scaling_fits: Dict of {label: ScalingFit} from scaling ladder result. - target_budgets: List of target compute budgets in FLOPs. - label: Dataset/experiment label to use for scaling fit. - recipe: ScalingRecipe with architecture/hyperparameter settings (includes vocab_size). - seq_len: Sequence length for training. - - Returns: - List of CandidateConfig for each budget. - - Raises: - RuntimeError: If any budget cannot be predicted (to prevent silent failures). - """ - configs = [] - for budget in target_budgets: - config = predict_optimal_config(scaling_fits, budget, label, recipe, seq_len) - if config is None: - raise RuntimeError( - f"Failed to predict optimal config for budget {budget:.2e} FLOPs " - f"with label '{label}'. Check that the label exists in scaling_fits " - f"and that the budget is within a valid range." 
- ) - configs.append(config) - return configs diff --git a/lib/marin/src/marin/scaling_laws/scaling_ladder.py b/lib/marin/src/marin/scaling_laws/scaling_ladder.py deleted file mode 100644 index 1b2664f964..0000000000 --- a/lib/marin/src/marin/scaling_laws/scaling_ladder.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Scaling ladder: compute-optimal training runs based on IsoFLOP analysis. - -This module provides functions and configs for training models with compute-optimal -configurations derived from IsoFLOP analysis. -""" - -import json -import logging -import os -from dataclasses import dataclass -from datetime import timedelta - -import fsspec -import jmp -from fray.cluster import ResourceConfig -from haliax.partitioning import ResourceAxis -from levanter.checkpoint import CheckpointerConfig -from levanter.data.text import LMMixtureDatasetConfig -from levanter.main.train_lm import TrainLmConfig -from levanter.tracker.wandb import WandbConfig -from levanter.trainer import TrainerConfig -from levanter.utils.mesh import MeshConfig - -from marin.scaling_laws.isoflop_analysis import ( - ScalingFit, - ScalingRecipe, - predict_optimal_config, -) -from marin.scaling_laws.tpu_utils import pick_v5p_type -from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class ScalingLadderRungConfig: - """Configuration for one rung of the scaling ladder (one compute-optimal training run). - - This config references an IsoFLOP analysis output and specifies - the target compute budget. At runtime, the optimal config is loaded - from the analysis output. - - The ScalingRecipe handles all model-specific decisions (architecture, optimizer) - and owns the vocab_size (derived from the tokenizer choice). - """ - - analysis_output_path: str - """Path to the IsoFLOP analysis output directory.""" - - target_budget: float - """Target compute budget in FLOPs.""" - - label: str - """Dataset label to use for scaling fit (e.g., 'nemo', 'comma', 'dclm').""" - - tokenized: LMMixtureDatasetConfig - """Tokenized dataset for training (with validation sets already added).""" - - output_path: str - """Where to write training outputs.""" - - recipe: ScalingRecipe - """Scaling recipe that handles model/optimizer config building (includes vocab_size).""" - - seq_len: int = 4096 - """Sequence length for training.""" - - -def run_scaling_ladder_rung(config: ScalingLadderRungConfig) -> None: - """Run one rung of the scaling ladder (one compute-optimal training run). 
- - The candidate contains all training configuration: - - Model config: `candidate.model_config` - - Optimizer config: `candidate.optimizer_config` - - Training schedule: `candidate.batch_size`, `candidate.train_steps` - """ - result_path = os.path.join(config.analysis_output_path, "isoflop_analysis_result.json") - fs, _, _ = fsspec.get_fs_token_paths(result_path) - - with fs.open(result_path, "r") as f: - analysis_result = json.load(f) - - scaling_fits: dict[str, ScalingFit] = {} - for key, value in analysis_result["scaling_fits"].items(): - if len(value) != 2: - raise ValueError(f"Expected 2 scaling fit values for '{key}', got {len(value)}") - scaling_fits[key] = ScalingFit(float(value[0]), float(value[1])) - - candidate = predict_optimal_config( - scaling_fits=scaling_fits, - target_flops=config.target_budget, - label=config.label, - recipe=config.recipe, - seq_len=config.seq_len, - ) - - if candidate is None: - raise RuntimeError( - f"Could not find optimal config for budget {config.target_budget:.2e} and label '{config.label}'" - ) - - params = candidate.model_config.total_trainable_params(config.recipe.vocab_size) - logger.info( - f"Training with optimal config for {config.target_budget:.2e} FLOPs:\n" - f" params={params:.2e}\n" - f" tokens={candidate.tokens:.2e}" - ) - - model_cfg = candidate.model_config - optimizer_cfg = candidate.optimizer_config - batch_size = candidate.batch_size - train_steps = candidate.train_steps - tpu_type = pick_v5p_type(candidate, config.seq_len, config.recipe) - - train_config = TrainLmConfig( - data=config.tokenized, - trainer=TrainerConfig( - tracker=WandbConfig( - project="marin", - tags=[ - "optimal-training", - f"FLOPs={config.target_budget:.1e}", - f"label={config.label}", - f"N={params:.1e}", - ], - ), - mp=jmp.get_policy("p=f32,c=bfloat16"), - train_batch_size=batch_size, - num_train_steps=train_steps, - steps_per_eval=1000, - checkpointer=CheckpointerConfig( - save_interval=timedelta(minutes=10), - keep=[dict(every=5000)], - ), - mesh=MeshConfig( - # Special axes for MoEs - # TODO: this is actually bad and we should remove, but keeping for now - compute_mapping={ - "token": (ResourceAxis.REPLICA_DCN, ResourceAxis.REPLICA, ResourceAxis.DATA), - "token_repeat": (ResourceAxis.REPLICA_DCN, ResourceAxis.REPLICA, ResourceAxis.DATA), - } - ), - allow_nondivisible_batch_size=True, - ), - train_seq_len=config.seq_len, - model=model_cfg, - optimizer=optimizer_cfg, - ) - - full_config = TrainLmOnPodConfig( - train_config=train_config, - resources=ResourceConfig.with_tpu(tpu_type), - output_path=config.output_path, - ) - - run_levanter_train_lm(full_config) diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index de3be0970f..48e5c11bcf 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -20,55 +20,17 @@ import jax.numpy as jnp -from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig -from levanter.models.qwen import Qwen3Config - from marin.scaling_laws.isoflop_analysis import ( DEFAULT_SEQ_LEN, CandidateConfig, - compute_training_flops, fit_scaling_laws, generate_training_configs, robust_quad_logx, - solve_for_batch_size, - solve_for_train_steps, ) # Import the concrete recipe and transform function from experiments from experiments.isoflop_sweep import Marin2025Recipe, parse_isoflop_run_name, transform_levanter_metrics -# --- FLOP computation tests --- - - -def test_flop_solvers_are_consistent(): - """Test that FLOP solvers correctly invert the FLOP calculation.""" - model_config = Qwen3Config( - 
max_seq_len=4096, - hidden_dim=768, - intermediate_dim=3072, - num_heads=12, - num_kv_heads=12, - num_layers=12, - rope=Llama3RotaryEmbeddingsConfig(), - ) - vocab_size = 128256 - seq_len = 4096 - - # Verify solve_for_batch_size inverts compute_training_flops - original_batch = 64 - train_steps = 10000 - target_flops = compute_training_flops(model_config, vocab_size, original_batch, train_steps, seq_len) - recovered_batch = solve_for_batch_size(model_config, vocab_size, target_flops, train_steps, seq_len) - assert abs(recovered_batch - original_batch) < 0.01 - - # Verify solve_for_train_steps inverts compute_training_flops - original_steps = 50000 - batch_size = 32 - target_flops = compute_training_flops(model_config, vocab_size, batch_size, original_steps, seq_len) - recovered_steps = solve_for_train_steps(model_config, vocab_size, target_flops, batch_size, seq_len) - assert abs(recovered_steps - original_steps) < 0.01 - - # --- Run name parsing tests --- @@ -106,13 +68,8 @@ def test_candidate_configs_within_tolerance(): if candidate is None: continue - achieved = compute_training_flops( - candidate.model_config, - recipe.vocab_size, - candidate.batch_size, - candidate.train_steps, - seq_len, - ) + # Compute training FLOPs inline: 3 * flops_per_token * batch * steps * seq_len + achieved = 3 * flops_per_token * candidate.batch_size * candidate.train_steps * seq_len relative_error = abs(achieved - budget) / budget assert relative_error <= flop_tolerance From e39d074a00051bf93c959cd05c74d75439ba0274 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 12 Jan 2026 17:25:23 -0800 Subject: [PATCH 73/79] Comment --- lib/marin/src/marin/scaling_laws/isoflop_analysis.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 0ddb3bf66c..a4be76bf38 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -329,10 +329,8 @@ def generate_training_configs( def robust_quad_logx(x: jnp.ndarray, y: jnp.ndarray, delta: float = 1.0) -> tuple[float, float, float]: """Fit a robust quadratic in log10(x) space using Huber loss. - Log10 space is used because FLOP budgets and token counts span many orders of - magnitude (e.g., 1e18 to 1e21+). Fitting in linear space would be numerically - unstable and dominated by the largest values. Log space provides better - conditioning and more interpretable coefficients. + Log10 space is used because sweeps are defined in powers of 10 (scientific + notation like 1e18, 1e19, 3e19), so log10 produces evenly-spaced points. The Huber loss provides robustness to outliers compared to standard least squares. 
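PATCH 73/79 above clarifies why `robust_quad_logx` fits a quadratic in log10(x) with a Huber penalty. As a quick illustration of that technique, here is a self-contained sketch using NumPy and SciPy rather than the module's `jax.numpy`/`jaxopt` stack; the helper names and sample numbers are hypothetical, and the vertex of the fitted parabola plays the role of the per-bucket token optimum.

```python
# Standalone illustration (not the library's robust_quad_logx): fit
# loss ~ a*log10(tokens)^2 + b*log10(tokens) + c with a Huber penalty and
# read the IsoFLOP optimum off the parabola's vertex.
import numpy as np
from scipy.optimize import minimize


def huber(r: np.ndarray, delta: float = 1.0) -> np.ndarray:
    a = np.abs(r)
    return np.where(a <= delta, 0.5 * r**2, delta * (a - 0.5 * delta))


def fit_quad_log10(tokens: np.ndarray, loss: np.ndarray, delta: float = 1.0):
    L = np.log10(tokens)

    def objective(params: np.ndarray) -> float:
        a, b, c = params
        return float(np.sum(huber(loss - (a * L**2 + b * L + c), delta)))

    init = np.polyfit(L, loss, 2)        # least-squares start, refined under the Huber objective
    a, b, c = minimize(objective, init, method="BFGS").x
    tokens_star = 10 ** (-b / (2 * a))   # vertex; assumes a > 0 (upward-opening parabola)
    return (a, b, c), tokens_star


# Hypothetical points for one FLOP bucket; budgets written in powers of 10
# land evenly spaced on the log10 axis, as the revised docstring notes.
toks = np.array([3e9, 1e10, 3e10, 1e11])
bpb = np.array([1.22, 1.04, 0.99, 1.08])
_, optimal_tokens = fit_quad_log10(toks, bpb)
```

The per-bucket minima produced this way are the inputs that `fit_scaling_laws` and `predict_optimal_config` consume downstream.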
From d70596ace1a443d9cbfeb1aba50eaa42a61a1f09 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 12 Jan 2026 19:30:16 -0800 Subject: [PATCH 74/79] as_input_path and pre-estimate memory --- experiments/exp2166_scaling_ladder_analysis.py | 9 +++++---- experiments/isoflop_sweep.py | 3 ++- .../marin/scaling_laws/eval_metrics_reader.py | 2 +- lib/marin/src/marin/scaling_laws/tpu_utils.py | 17 ++++------------- 4 files changed, 12 insertions(+), 19 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index c34aec7f1a..4861f1b27e 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -39,7 +39,7 @@ run_isoflop_analysis_step, ) from experiments.simple_train_config import SimpleTrainConfig -from marin.execution.executor import ExecutorStep, executor_main, output_path_of, this_output_path +from marin.execution.executor import ExecutorStep, executor_main, this_output_path from marin.processing.tokenize import add_validation_sets_to_mixture from marin.scaling_laws import ScalingFit, predict_optimal_config from marin.scaling_laws.tpu_utils import pick_v5p_type @@ -100,7 +100,8 @@ def run_optimal_training( f" tokens={candidate.tokens:.2e}" ) - tpu_type = pick_v5p_type(candidate, SEQ_LEN, MARIN_2025_RECIPE) + estimated_memory = MARIN_2025_RECIPE.estimate_memory_bytes(candidate, SEQ_LEN) + tpu_type = pick_v5p_type(estimated_memory) train_config = SimpleTrainConfig( resources=ResourceConfig.with_tpu(tpu_type), @@ -132,7 +133,7 @@ def run_optimal_training( name=f"{EXPERIMENT_NAME}-analysis", fn=run_isoflop_analysis_step, config=IsoFlopAnalysisConfig( - training_runs=tuple(output_path_of(r) for r in nemotron_training), + training_runs=tuple(r.as_input_name() for r in nemotron_training), output_path=this_output_path(), recipe=MARIN_2025_RECIPE, ), @@ -145,7 +146,7 @@ def run_optimal_training( step = ExecutorStep( name=f"{EXPERIMENT_NAME}-optimal-{budget:.0e}", fn=lambda b=budget: run_optimal_training( - analysis_output_path=output_path_of(analysis_step), + analysis_output_path=analysis_step.as_input_name(), target_budget=b, label=LABEL, ), diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index ac5873a549..182505c2ea 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -562,7 +562,8 @@ def create_isoflop_sweep_steps( # Create ExecutorSteps for each candidate configuration for candidate in candidates: model_config = candidate.model_config - tpu_type = pick_v5p_type(candidate, seq_len, recipe) + estimated_memory = recipe.estimate_memory_bytes(candidate, seq_len) + tpu_type = pick_v5p_type(estimated_memory) # Use local naming with architecture details for backward compatibility run_name = _format_run_name( diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index 20cf3603f8..5e5cda3aab 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -22,7 +22,7 @@ import logging import os from collections.abc import Sequence -from dataclasses import dataclass +from typing import Protocol import fsspec import wandb diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py index d12bfe8fcc..b40b49fd23 100644 --- a/lib/marin/src/marin/scaling_laws/tpu_utils.py +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -20,8 +20,6 @@ import math -from 
marin.scaling_laws.isoflop_analysis import CandidateConfig, ScalingRecipe - # ---------------- TPU v5p Hardware Constants ---------------- # These constants are specific to TPU v5p pods. @@ -35,17 +33,11 @@ """Available TPU v5p core configurations (slice sizes).""" -def pick_v5p_type( - candidate: CandidateConfig, - seq_len: int, - recipe: ScalingRecipe, -) -> str: - """Select the smallest TPU v5p slice that fits the model in float32. +def pick_v5p_type(estimated_memory_bytes: int) -> str: + """Select the smallest TPU v5p slice that fits the estimated memory. Args: - candidate: CandidateConfig with model_config and tokens. - seq_len: Sequence length. - recipe: ScalingRecipe for memory estimation (includes vocab_size). + estimated_memory_bytes: Estimated memory requirement in bytes. Returns: TPU slice name, e.g., "v5p-8" or "v5p-32". @@ -53,9 +45,8 @@ def pick_v5p_type( Raises: ValueError: If the model is too large for available v5p slices. """ - need_bytes = recipe.estimate_memory_bytes(candidate, seq_len) chip_bytes = HBM_PER_CHIP_GIB * 1024**3 - chips = math.ceil(need_bytes / chip_bytes) + chips = math.ceil(estimated_memory_bytes / chip_bytes) cores_req = chips * CORES_PER_CHIP valid = [c for c in V5P_CORE_OPTIONS if c >= cores_req] From 0fd52e459803abf6fcb594534a72ba772644cce2 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 12 Jan 2026 19:43:12 -0800 Subject: [PATCH 75/79] Remove Needless Protocl --- .../exp2166_scaling_ladder_analysis.py | 2 +- experiments/isoflop_sweep.py | 19 +++++---- .../marin/scaling_laws/eval_metrics_reader.py | 40 +++++++------------ 3 files changed, 27 insertions(+), 34 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 4861f1b27e..b0a212afd7 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -133,7 +133,7 @@ def run_optimal_training( name=f"{EXPERIMENT_NAME}-analysis", fn=run_isoflop_analysis_step, config=IsoFlopAnalysisConfig( - training_runs=tuple(r.as_input_name() for r in nemotron_training), + training_runs=[r.as_input_name() for r in nemotron_training], output_path=this_output_path(), recipe=MARIN_2025_RECIPE, ), diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 182505c2ea..526309891c 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -22,7 +22,11 @@ import math import os import re -from collections.abc import Iterator + +import json +import fsspec + +from collections.abc import Iterator, Sequence from dataclasses import dataclass, replace from levanter.data.text import LMMixtureDatasetConfig @@ -42,6 +46,7 @@ from fray.cluster import ResourceConfig from marin.execution.executor import ExecutorStep, InputName, executor_main from marin.processing.tokenize import get_vocab_size_for_tokenizer, lm_mixture_data_config + from marin.scaling_laws import ( CandidateConfig, FitScalingLawsResult, @@ -412,7 +417,7 @@ class IsoFlopAnalysisConfig: This config is for use with ExecutorStep. """ - training_runs: tuple[str, ...] 
+ training_runs: Sequence[str] """Training run output paths (executor resolves InputName to str at runtime).""" output_path: str @@ -449,12 +454,12 @@ def run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> FitScalingLawsRe Returns: FitScalingLawsResult with fitted scaling laws """ - import json - - import fsspec - # Read raw records from training runs - raw_records = read_raw_records(config) + raw_records = read_raw_records( + training_runs=config.training_runs, + metrics_filename=config.metrics_filename, + wandb_entity_project=config.wandb_entity_project, + ) if not raw_records: logger.warning("No eval metrics found in training runs") diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index 5e5cda3aab..82a8bb503f 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -14,15 +14,14 @@ """Base infrastructure for eval metrics analysis. -This module provides a config and utilities for analysis jobs that -read tracker_metrics.jsonl files from completed training runs. +This module provides utilities for analysis jobs that read tracker_metrics.jsonl +files from completed training runs. """ import json import logging import os from collections.abc import Sequence -from typing import Protocol import fsspec import wandb @@ -86,39 +85,28 @@ def _backfill_metrics_from_wandb( return False -@dataclass(frozen=True) -class EvalMetricsAnalysisConfig: - """Config for analyses that read eval metrics from training runs. - - The training_runs field creates blocking dependencies on the training jobs. - """ - - training_runs: Sequence[str] - """List of training run output paths (executor resolves InputName to str at runtime).""" - - metrics_filename: str = "tracker_metrics.jsonl" - """Name of the metrics file within each checkpoint directory.""" - - wandb_entity_project: str = f"{WANDB_ENTITY}/{WANDB_PROJECT}" - """WandB entity/project to query for backfill (format: 'entity/project').""" - - -def read_raw_records(config: EvalMetricsAnalysisConfig) -> list[dict]: +def read_raw_records( + training_runs: Sequence[str], + metrics_filename: str = "tracker_metrics.jsonl", + wandb_entity_project: str = f"{WANDB_ENTITY}/{WANDB_PROJECT}", +) -> list[dict]: """Read raw eval metrics from training runs. This is the shared utility that all analysis subtypes use to load metrics. It handles reading JSONL files and WandB backfill when files are missing. Args: - config: Analysis config with training_runs and backfill settings + training_runs: List of training run output paths. + metrics_filename: Name of the metrics file within each checkpoint directory. + wandb_entity_project: WandB entity/project to query for backfill (format: 'entity/project'). Returns: List of raw records, each containing config, summary, run_index, and run_path. 
""" all_records = [] - for i, run_path in enumerate(config.training_runs): - metrics_file = os.path.join(run_path, config.metrics_filename) + for i, run_path in enumerate(training_runs): + metrics_file = os.path.join(run_path, metrics_filename) fs, _, _ = fsspec.get_fs_token_paths(metrics_file) @@ -128,7 +116,7 @@ def read_raw_records(config: EvalMetricsAnalysisConfig) -> list[dict]: success = _backfill_metrics_from_wandb( checkpoint_path=run_path, metrics_file=metrics_file, - entity_project=config.wandb_entity_project, + entity_project=wandb_entity_project, ) if not success: raise RuntimeError( @@ -149,5 +137,5 @@ def read_raw_records(config: EvalMetricsAnalysisConfig) -> list[dict]: if not all_records: logger.warning("No eval metrics found in any training runs") - logger.info(f"Loaded {len(all_records)} evaluation records from {len(config.training_runs)} runs") + logger.info(f"Loaded {len(all_records)} evaluation records from {len(training_runs)} runs") return all_records From dbe47843c2a6e4f61c2b13fd63fc1f13c9889d66 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Mon, 12 Jan 2026 19:44:55 -0800 Subject: [PATCH 76/79] Read Eval Records --- experiments/isoflop_sweep.py | 4 ++-- lib/marin/src/marin/scaling_laws/eval_metrics_reader.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 526309891c..232e06543c 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -57,7 +57,7 @@ pick_v5p_type, round_flops_to_bucket, ) -from marin.scaling_laws.eval_metrics_reader import read_raw_records +from marin.scaling_laws.eval_metrics_reader import read_eval_records from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT logger = logging.getLogger(__name__) @@ -455,7 +455,7 @@ def run_isoflop_analysis_step(config: IsoFlopAnalysisConfig) -> FitScalingLawsRe FitScalingLawsResult with fitted scaling laws """ # Read raw records from training runs - raw_records = read_raw_records( + raw_records = read_eval_records( training_runs=config.training_runs, metrics_filename=config.metrics_filename, wandb_entity_project=config.wandb_entity_project, diff --git a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py index 82a8bb503f..c557b83c72 100644 --- a/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py +++ b/lib/marin/src/marin/scaling_laws/eval_metrics_reader.py @@ -85,7 +85,7 @@ def _backfill_metrics_from_wandb( return False -def read_raw_records( +def read_eval_records( training_runs: Sequence[str], metrics_filename: str = "tracker_metrics.jsonl", wandb_entity_project: str = f"{WANDB_ENTITY}/{WANDB_PROJECT}", From 0c7592cb38a62b3914197b3414a83345a53ca4d5 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 13 Jan 2026 18:00:09 -0800 Subject: [PATCH 77/79] This can run the full ladder now --- .../exp2166_scaling_ladder_analysis.py | 168 +++++++++++++----- experiments/isoflop_sweep.py | 61 +++++-- lib/marin/src/marin/scaling_laws/__init__.py | 2 - .../marin/scaling_laws/isoflop_analysis.py | 99 ++--------- lib/marin/src/marin/scaling_laws/tpu_utils.py | 2 +- tests/test_scaling_laws.py | 13 +- 6 files changed, 183 insertions(+), 162 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index b0a212afd7..698417ddf9 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -23,14 +23,25 @@ 3. 
Optionally trains compute-optimal models at larger target budgets """ +import dataclasses import json import logging import os +from dataclasses import dataclass +from datetime import timedelta import fsspec +import jmp from fray.cluster import ResourceConfig - -from experiments.defaults import default_train, default_validation_sets +from haliax.partitioning import ResourceAxis +from levanter.checkpoint import CheckpointerConfig +from levanter.data.text import LMDatasetSourceConfig, LMMixtureDatasetConfig +from levanter.main import train_lm +from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerConfig +from levanter.utils.mesh import MeshConfig + +from experiments.defaults import default_validation_sets from experiments.isoflop_sweep import ( IsoFlopAnalysisConfig, MARIN_2025_RECIPE, @@ -38,11 +49,12 @@ nemotron_mix, run_isoflop_analysis_step, ) -from experiments.simple_train_config import SimpleTrainConfig +from experiments.llama import llama3_tokenizer from marin.execution.executor import ExecutorStep, executor_main, this_output_path -from marin.processing.tokenize import add_validation_sets_to_mixture +from marin.processing.tokenize import step_to_lm_mixture_component from marin.scaling_laws import ScalingFit, predict_optimal_config from marin.scaling_laws.tpu_utils import pick_v5p_type +from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm logger = logging.getLogger(__name__) @@ -50,27 +62,42 @@ nemotron_training, _ = MARIN_SCALING_SUITES["nemotron"] # --- Configuration --- -TARGET_BUDGETS: list[float] = [1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20] +TARGET_BUDGETS: list[float] = [1e18, 3e18, 6e18, 1e19, 3e19, 6e19, 1e20, 1e21, 1e22, 1e23, 1e24] EXPERIMENT_NAME = "exp2166-scaling-ladder-nemotron-validation" LABEL = "nemo-wider-depth-adapt" -TOKENIZER = "stanford-crfm/marin-tokenizer" SEQ_LEN = 4096 -# Add validation sets to the training mixture -nemotron_mix_with_validation = add_validation_sets_to_mixture(nemotron_mix, default_validation_sets(tokenizer=TOKENIZER)) +@dataclass(frozen=True) +class OptimalTrainingConfig: + """Config for training a compute-optimal model based on scaling law analysis.""" + + analysis_output_path: str + """Path to the analysis output containing scaling fits.""" + + target_budget: float + """Target compute budget in FLOPs.""" + + label: str + """Dataset/experiment label to use for scaling fit lookup.""" + + output_path: str + """Output path for checkpoints and logs.""" + + tokenized: LMMixtureDatasetConfig + """Tokenized dataset for training. Executor will resolve InputName and unwrap VersionedValue.""" + + validation_configs: dict[str, LMDatasetSourceConfig] | None = None + """Validation set configs. Passed through config so executor resolves InputName paths.""" -def run_optimal_training( - analysis_output_path: str, - target_budget: float, - label: str, -) -> ExecutorStep: - """Create an ExecutorStep for compute-optimal training at the given budget. - Loads scaling fits from the analysis output, predicts the optimal config, - and returns an ExecutorStep using default_train. +def run_optimal_training(config: OptimalTrainingConfig) -> None: + """Run compute-optimal training at the given budget. + + Reads scaling fits from analysis output, predicts optimal config, + builds training config, and runs training directly. 
""" - result_path = os.path.join(analysis_output_path, "isoflop_analysis_result.json") + result_path = os.path.join(config.analysis_output_path, "isoflop_analysis_result.json") fs, _, _ = fsspec.get_fs_token_paths(result_path) with fs.open(result_path, "r") as f: @@ -84,48 +111,92 @@ def run_optimal_training( candidate = predict_optimal_config( scaling_fits=scaling_fits, - target_flops=target_budget, - label=label, + target_flops=config.target_budget, + label=config.label, recipe=MARIN_2025_RECIPE, seq_len=SEQ_LEN, ) if candidate is None: - raise RuntimeError(f"Could not find optimal config for budget {target_budget:.2e} and label '{label}'") + raise RuntimeError( + f"Could not find optimal config for budget {config.target_budget:.2e} and label '{config.label}'" + ) params = candidate.model_config.total_trainable_params(MARIN_2025_RECIPE.vocab_size) logger.info( - f"Training with optimal config for {target_budget:.2e} FLOPs:\n" + f"Training with optimal config for {config.target_budget:.2e} FLOPs:\n" f" params={params:.2e}\n" f" tokens={candidate.tokens:.2e}" ) estimated_memory = MARIN_2025_RECIPE.estimate_memory_bytes(candidate, SEQ_LEN) tpu_type = pick_v5p_type(estimated_memory) - - train_config = SimpleTrainConfig( - resources=ResourceConfig.with_tpu(tpu_type), - train_batch_size=candidate.batch_size, - num_train_steps=candidate.train_steps, - learning_rate=candidate.optimizer_config.learning_rate, - optimizer_config=candidate.optimizer_config, + logger.info(f"Estimated memory: {estimated_memory / 1e9:.2f} GB, TPU type: {tpu_type}") + + # Build TrainLmConfig directly (like old run_scaling_ladder_rung) + # config.tokenized is already processed by executor's instantiate_config + data = config.tokenized + if config.validation_configs: + # Merge validation configs into the data mixture with weight 0 + new_configs = { + **data.configs, + **{name: cfg for name, cfg in config.validation_configs.items() if name not in data.configs}, + } + if isinstance(data.train_weights, dict): + new_weights = { + **data.train_weights, + **{name: 0.0 for name in config.validation_configs if name not in data.train_weights}, + } + else: + # Varying weights case + new_weights = [ + (step_idx, {**weights, **{name: 0.0 for name in config.validation_configs if name not in weights}}) + for step_idx, weights in data.train_weights + ] + data = dataclasses.replace(data, configs=new_configs, train_weights=new_weights) + + inner_config = train_lm.TrainLmConfig( + data=data, + trainer=TrainerConfig( + tracker=WandbConfig( + project="marin", + tags=[ + "optimal-training", + f"FLOPs={config.target_budget:.1e}", + f"label={config.label}", + f"N={params:.1e}", + ], + ), + mp=jmp.get_policy("p=f32,c=bfloat16"), + train_batch_size=candidate.batch_size, + num_train_steps=candidate.train_steps, + steps_per_eval=1000, + checkpointer=CheckpointerConfig( + save_interval=timedelta(minutes=10), + keep=[dict(every=5000)], + ), + mesh=MeshConfig( + compute_mapping={ + "token": (ResourceAxis.REPLICA_DCN, ResourceAxis.REPLICA, ResourceAxis.DATA), + "token_repeat": (ResourceAxis.REPLICA_DCN, ResourceAxis.REPLICA, ResourceAxis.DATA), + } + ), + allow_nondivisible_batch_size=True, + ), train_seq_len=SEQ_LEN, + model=candidate.model_config, + optimizer=candidate.optimizer_config, ) - return default_train( - name=f"{EXPERIMENT_NAME}-optimal-{target_budget:.0e}", - tokenized=nemotron_mix_with_validation, - model_config=candidate.model_config, - train_config=train_config, - tags=[ - "optimal-training", - f"FLOPs={target_budget:.1e}", - 
f"label={label}", - f"N={params:.1e}", - ], - use_default_validation=False, # Already added above + pod_config = TrainLmOnPodConfig( + train_config=inner_config, + resources=ResourceConfig.with_tpu(tpu_type), + output_path=config.output_path, ) + logger.info(f"Launching training with resources: {pod_config.resources}") + run_levanter_train_lm(pod_config) + # --- Step 1: IsoFLOP Analysis --- # Creates scaling law fits from the training runs @@ -139,18 +210,29 @@ def run_optimal_training( ), ) +# --- Create validation configs --- +# Convert validation TokenizerSteps to LMDatasetSourceConfig at module import time. +# This way instantiate_config resolves InputName paths before run_optimal_training runs. +validation_steps = default_validation_sets(tokenizer=llama3_tokenizer) +validation_configs = { + name: step_to_lm_mixture_component(step, include_raw_paths=False) for name, step in validation_steps.items() +} + # --- Step 2: Optimal Training Runs --- # Train compute-optimal models at each target budget optimal_runs: list[ExecutorStep] = [] for budget in TARGET_BUDGETS: step = ExecutorStep( name=f"{EXPERIMENT_NAME}-optimal-{budget:.0e}", - fn=lambda b=budget: run_optimal_training( + fn=run_optimal_training, + config=OptimalTrainingConfig( analysis_output_path=analysis_step.as_input_name(), - target_budget=b, + target_budget=budget, label=LABEL, + output_path=this_output_path(), + tokenized=nemotron_mix, + validation_configs=validation_configs, ), - config=None, ) optimal_runs.append(step) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 232e06543c..38d70db170 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -53,7 +53,6 @@ IsoFlopRecord, ScalingRecipe, fit_scaling_laws, - generate_training_configs, pick_v5p_type, round_flops_to_bucket, ) @@ -257,10 +256,13 @@ def vocab_size(self) -> int: # --- Constraints --- max_learning_rate: float = 0.01 min_batch_size: int = 8 + max_batch_size: int = 8192 + # max_params scales with sqrt(budget) above 3e20, with floor of 12B and ceiling of 1T + base_max_params: float = 12e9 + base_max_params_budget: float = 3e20 + global_max_params: float = 1e12 - # --- Search bounds for isoflop sweeps --- - min_hidden_pow: int = 9 - max_hidden_pow: int = 12 + # --- Search step sizes for isoflop sweeps --- small_budget_step_size: int = 128 large_budget_step_size: int = 256 budget_step_threshold: float = 9e18 @@ -287,6 +289,15 @@ def _get_step_size(self, budget: float) -> int: return self.large_budget_step_size return self.small_budget_step_size + def _max_params_for_budget(self, budget: float) -> float: + """Compute max_params as a function of budget. + + Returns base_max_params for budgets <= base_max_params_budget, + then scales with sqrt(budget) for larger budgets, capped at global_max_params. + """ + scaling = self.base_max_params * math.sqrt(budget / self.base_max_params_budget) + return min(max(self.base_max_params, scaling), self.global_max_params) + def _build_model_config_from_hidden_size(self, hidden_size: int, seq_len: int = DEFAULT_SEQ_LEN) -> LlamaConfig: """Build model config from hidden_size directly.""" if hidden_size % self.hidden_head_ratio != 0: @@ -333,12 +344,14 @@ def build_model_configs( budget: float, seq_len: int = DEFAULT_SEQ_LEN, ) -> Iterator[LlamaConfig]: - """Yield candidate model architectures for the given FLOP budget.""" + """Yield candidate model architectures for the given FLOP budget. 
+ + Uses wide bounds (2**9 to 2**17) and relies on batch size filtering + in build_candidate_config to select valid configurations. + """ step_size = self._get_step_size(budget) - min_hidden = 2**self.min_hidden_pow - max_hidden = 2**self.max_hidden_pow - for hidden_size in range(min_hidden, max_hidden + 1, step_size): + for hidden_size in range(2**9, 2**17, step_size): yield self._build_model_config_from_hidden_size(hidden_size, seq_len) def build_candidate_config( @@ -366,8 +379,8 @@ def build_candidate_config( batch_size //= 2 lr = self._compute_learning_rate(batch_size, hidden_size) - # Return None if batch_size is below minimum - if batch_size < self.min_batch_size: + # Return None if batch_size is outside valid range + if batch_size < self.min_batch_size or batch_size > self.max_batch_size: return None # Compute train_steps to achieve target tokens @@ -401,6 +414,24 @@ def build_candidate_config( flops_budget=flops_budget, ) + def candidates_for_budget( + self, + budget: float, + seq_len: int = DEFAULT_SEQ_LEN, + ) -> Iterator[CandidateConfig]: + """Yield valid candidate training configs for the given FLOP budget.""" + max_params = self._max_params_for_budget(budget) + for model_config in self.build_model_configs(budget, seq_len): + # Skip models that exceed budget-dependent max_params + params = model_config.total_trainable_params(self.vocab_size) + if params > max_params: + continue + flops_per_token = model_config.flops_per_token(self.vocab_size, seq_len) + tokens = budget / (3 * flops_per_token) + candidate = self.build_candidate_config(model_config, tokens, budget, seq_len) + if candidate is not None: + yield candidate + MARIN_2025_RECIPE = Marin2025Recipe() """Default Marin scaling recipe.""" @@ -530,9 +561,6 @@ def create_isoflop_sweep_steps( ) -> tuple[list[ExecutorStep], list[CandidateConfig]]: """Create ExecutorSteps for an ISOFlop sweep. - This function creates ExecutorSteps directly in experiment code, using - `generate_training_configs()` from the library to compute configs. - Args: tokenized: Tokenized dataset to train on. experiment_name: Name suffix for the experiment (e.g., 'nemo', 'dclm'). @@ -546,12 +574,7 @@ def create_isoflop_sweep_steps( - steps: Training and evaluation ExecutorSteps for the sweep. - candidates: CandidateConfig for each training run with full config details. 
""" - # Generate complete training configs from the library - candidates = generate_training_configs( - budgets=budgets, - recipe=recipe, - seq_len=seq_len, - ) + candidates = [c for budget in budgets for c in recipe.candidates_for_budget(budget, seq_len)] # Base config for training runs (values overridden per-candidate) base_train_config = SimpleTrainConfig( diff --git a/lib/marin/src/marin/scaling_laws/__init__.py b/lib/marin/src/marin/scaling_laws/__init__.py index 618ef8e833..cfeb053bd0 100644 --- a/lib/marin/src/marin/scaling_laws/__init__.py +++ b/lib/marin/src/marin/scaling_laws/__init__.py @@ -25,7 +25,6 @@ ScalingFit, ScalingRecipe, fit_scaling_laws, - generate_training_configs, predict_optimal_config, round_flops_to_bucket, ) @@ -57,7 +56,6 @@ "create_isoflop_plot", "create_scaling_plot", "fit_scaling_laws", - "generate_training_configs", "pick_v5p_type", "predict_optimal_config", "round_flops_to_bucket", diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index a4be76bf38..2b25734ccc 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -26,12 +26,11 @@ Key functions: - fit_scaling_laws(records): Fit scaling laws from typed records - predict_optimal_config(): Predict optimal training config for a target budget -- generate_training_configs(): Generate training configs for an isoflop sweep """ import logging import math -from collections.abc import Iterator, Sequence +from collections.abc import Iterator from dataclasses import dataclass from typing import NamedTuple, Protocol @@ -174,13 +173,8 @@ class ScalingRecipe(Protocol): """Protocol defining the interface for scaling law recipes. Concrete implementations (e.g., Marin2025Recipe) should implement these - model-specific methods. Orchestration logic (generating training configs, - predicting optimal configs) is handled by library functions that use - these core methods. - - The recipe owns the vocab_size, which is derived from the tokenizer choice. - This ensures consistency and simplifies the API by not requiring vocab_size - to be threaded through every function call. + model-specific methods. The recipe owns the vocab_size, which is derived + from the tokenizer choice. """ name: str @@ -193,32 +187,17 @@ def estimate_memory_bytes(self, candidate: CandidateConfig, seq_len: int = DEFAU """Estimate memory usage in bytes for training a candidate configuration.""" ... - def build_model_configs( + def candidates_for_budget( self, budget: float, seq_len: int = DEFAULT_SEQ_LEN, - ) -> Iterator[ModelConfiguration]: - """Yield candidate model architectures for the given FLOP budget. - - A typical implementation will iterate over hidden sizes (the primary - architectural knob) and yield model configs for each feasible size. - """ - ... - - def build_candidate_config( - self, - model_config: ModelConfiguration, - tokens: float, - flops_budget: float, - seq_len: int = DEFAULT_SEQ_LEN, - ) -> CandidateConfig | None: - """Build complete training config for a model and token count. - - Solves for batch_size, computes optimizer hyperparameters (learning rate, - beta2, etc.), and returns a complete CandidateConfig. + ) -> Iterator[CandidateConfig]: + """Yield valid candidate training configs for the given FLOP budget. - Returns None if the configuration is invalid (e.g., batch_size < minimum - after learning rate constraints are applied). 
+ This is the main entry point for generating training configurations. + Implementations should iterate over model architectures and yield + complete CandidateConfig objects with model, optimizer, batch size, + and training steps all configured. """ ... @@ -277,55 +256,6 @@ def round_flops_to_bucket(flops: float, base: float = 1.1) -> float: return base ** round(k) -# ---------------- Training Config Generation ---------------- - - -def generate_training_configs( - budgets: Sequence[float], - recipe: ScalingRecipe, - seq_len: int = DEFAULT_SEQ_LEN, -) -> list[CandidateConfig]: - """Generate training configurations for an isoflop sweep. - - For each FLOP budget: - 1. Gets candidate model architectures from the recipe - 2. Computes tokens needed to achieve the budget: tokens = budget / (3 * flops_per_token) - 3. Builds complete training configs via recipe.build_candidate_config() - 4. Filters out invalid configs (where build_candidate_config returns None) - - Args: - budgets: Sequence of FLOP budgets to generate configs for. - recipe: ScalingRecipe with architecture/hyperparameter settings. - seq_len: Sequence length for training. - - Returns: - List of CandidateConfig, each containing model_config, optimizer_config, - batch_size, train_steps, tokens, and flops_budget. - - Example: - >>> from marin.scaling_laws import generate_training_configs, DEFAULT_BUDGETS - >>> configs = generate_training_configs(budgets=DEFAULT_BUDGETS, recipe=recipe) - >>> for cfg in configs: - ... print(f"N={cfg.model_config.total_trainable_params(recipe.vocab_size):.1e}") - ... print(f"batch_size={cfg.batch_size}, steps={cfg.train_steps}") - ... print(f"lr={cfg.optimizer_config.learning_rate}") - """ - results: list[CandidateConfig] = [] - - for budget in budgets: - for model_config in recipe.build_model_configs(budget, seq_len): - # Compute tokens directly from budget - flops_per_token = model_config.flops_per_token(recipe.vocab_size, seq_len) - tokens = budget / (3 * flops_per_token) - - # Build complete training config (returns None if invalid) - candidate = recipe.build_candidate_config(model_config, tokens, budget, seq_len) - if candidate is not None: - results.append(candidate) - - return results - - def robust_quad_logx(x: jnp.ndarray, y: jnp.ndarray, delta: float = 1.0) -> tuple[float, float, float]: """Fit a robust quadratic in log10(x) space using Huber loss. 
@@ -499,14 +429,7 @@ def predict_optimal_config( logger.info(f"Predicted optimal tokens for {target_flops:.2e} FLOPs: {optimal_tokens:.2e}") - # Build candidates using the new API - candidates: list[CandidateConfig] = [] - for model_config in recipe.build_model_configs(target_flops, seq_len): - flops_per_token = model_config.flops_per_token(recipe.vocab_size, seq_len) - tokens = target_flops / (3 * flops_per_token) - candidate = recipe.build_candidate_config(model_config, tokens, target_flops, seq_len) - if candidate is not None: - candidates.append(candidate) + candidates = list(recipe.candidates_for_budget(target_flops, seq_len)) if not candidates: logger.warning(f"No valid candidates found for budget {target_flops:.2e}") diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py index b40b49fd23..1a20cbf7c1 100644 --- a/lib/marin/src/marin/scaling_laws/tpu_utils.py +++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py @@ -29,7 +29,7 @@ CORES_PER_CHIP = 2 """Number of cores per TPU v5p chip.""" -V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512] +V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512, 1024, 2048] """Available TPU v5p core configurations (slice sizes).""" diff --git a/tests/test_scaling_laws.py b/tests/test_scaling_laws.py index 48e5c11bcf..1ac4c29688 100644 --- a/tests/test_scaling_laws.py +++ b/tests/test_scaling_laws.py @@ -24,7 +24,6 @@ DEFAULT_SEQ_LEN, CandidateConfig, fit_scaling_laws, - generate_training_configs, robust_quad_logx, ) @@ -93,7 +92,7 @@ def test_robust_quad_logx_fits_quadratic(): # --- Snapshot test for config generation --- -# Snapshot of expected output for generate_training_configs with budget=3e18 training FLOPs. +# Snapshot of expected output for candidates_for_budget with budget=3e18 training FLOPs. EXPECTED_ISOFLOP_CONFIGS_3E18 = [ {"batch_size": 32, "train_steps": 32844, "flops_budget": 3e18}, {"batch_size": 16, "train_steps": 46274, "flops_budget": 3e18}, @@ -103,22 +102,18 @@ def test_robust_quad_logx_fits_quadratic(): ] -def test_generate_training_configs_snapshot(): - """Snapshot test: verify generate_training_configs produces expected configs. +def test_candidates_for_budget_snapshot(): + """Snapshot test: verify candidates_for_budget produces expected configs. This ensures reproducibility of the config generation algorithm. 
""" recipe = Marin2025Recipe() - result = generate_training_configs( - budgets=(3e18,), - recipe=recipe, - ) + result = list(recipe.candidates_for_budget(budget=3e18)) assert len(result) == len(EXPECTED_ISOFLOP_CONFIGS_3E18) for i, (candidate, expected) in enumerate(zip(result, EXPECTED_ISOFLOP_CONFIGS_3E18, strict=True)): assert isinstance(candidate, CandidateConfig) - # batch_size and train_steps are now directly on the candidate assert candidate.batch_size == expected["batch_size"], f"Config {i}: batch_size mismatch" assert candidate.train_steps == expected["train_steps"], f"Config {i}: train_steps mismatch" assert candidate.flops_budget == expected["flops_budget"], f"Config {i}: flops_budget mismatch" From 20427d4f34e8f5d85df1eadc9117d337354676f5 Mon Sep 17 00:00:00 2001 From: Helw150 Date: Tue, 13 Jan 2026 18:12:53 -0800 Subject: [PATCH 78/79] Mild Simplify --- experiments/exp2166_scaling_ladder_analysis.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index 698417ddf9..df9e794628 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -23,11 +23,10 @@ 3. Optionally trains compute-optimal models at larger target budgets """ -import dataclasses import json import logging import os -from dataclasses import dataclass +from dataclasses import dataclass, replace from datetime import timedelta import fsspec @@ -140,7 +139,7 @@ def run_optimal_training(config: OptimalTrainingConfig) -> None: # Merge validation configs into the data mixture with weight 0 new_configs = { **data.configs, - **{name: cfg for name, cfg in config.validation_configs.items() if name not in data.configs}, + **{k: v for k, v in config.validation_configs.items() if k not in data.configs}, } if isinstance(data.train_weights, dict): new_weights = { @@ -153,7 +152,7 @@ def run_optimal_training(config: OptimalTrainingConfig) -> None: (step_idx, {**weights, **{name: 0.0 for name in config.validation_configs if name not in weights}}) for step_idx, weights in data.train_weights ] - data = dataclasses.replace(data, configs=new_configs, train_weights=new_weights) + data = replace(data, configs=new_configs, train_weights=new_weights) inner_config = train_lm.TrainLmConfig( data=data, From 705a47aaa3fa2cbd3e2f709f2227e6796581896f Mon Sep 17 00:00:00 2001 From: Helw150 Date: Sat, 17 Jan 2026 13:20:19 -0800 Subject: [PATCH 79/79] More Grad Accum --- .../exp2166_scaling_ladder_analysis.py | 52 +++++++++++++++---- experiments/isoflop_sweep.py | 40 +++++++++++--- .../marin/scaling_laws/isoflop_analysis.py | 2 +- lib/marin/src/marin/scaling_laws/tpu_utils.py | 2 +- 4 files changed, 77 insertions(+), 19 deletions(-) diff --git a/experiments/exp2166_scaling_ladder_analysis.py b/experiments/exp2166_scaling_ladder_analysis.py index df9e794628..1ff51d1742 100644 --- a/experiments/exp2166_scaling_ladder_analysis.py +++ b/experiments/exp2166_scaling_ladder_analysis.py @@ -52,9 +52,10 @@ from marin.execution.executor import ExecutorStep, executor_main, this_output_path from marin.processing.tokenize import step_to_lm_mixture_component from marin.scaling_laws import ScalingFit, predict_optimal_config -from marin.scaling_laws.tpu_utils import pick_v5p_type +from marin.scaling_laws.tpu_utils import pick_v5p_type, HBM_PER_CHIP_GIB from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm +logging.basicConfig(level=logging.INFO) logger = 
logging.getLogger(__name__) # Get training steps from the isoflop sweep @@ -65,6 +66,7 @@ EXPERIMENT_NAME = "exp2166-scaling-ladder-nemotron-validation" LABEL = "nemo-wider-depth-adapt" SEQ_LEN = 4096 +MAX_TPU_TYPE = "v5p-64" # Cap TPU size; use gradient accumulation for larger models @dataclass(frozen=True) @@ -122,15 +124,46 @@ def run_optimal_training(config: OptimalTrainingConfig) -> None: ) params = candidate.model_config.total_trainable_params(MARIN_2025_RECIPE.vocab_size) - logger.info( - f"Training with optimal config for {config.target_budget:.2e} FLOPs:\n" - f" params={params:.2e}\n" - f" tokens={candidate.tokens:.2e}" + estimated_memory = MARIN_2025_RECIPE.estimate_memory_bytes(candidate) + + # Compute TPU type and gradient accumulation settings + max_cores = int(MAX_TPU_TYPE.split("-")[1]) + num_chips = max_cores // 2 + max_memory = num_chips * HBM_PER_CHIP_GIB * 1024**3 + + per_device_parallelism: int | None = None + if estimated_memory <= max_memory: + # Fits without gradient accumulation + tpu_type = pick_v5p_type(estimated_memory) + else: + # Need gradient accumulation to fit in MAX_TPU_TYPE + tpu_type = MAX_TPU_TYPE + microbatch_size = candidate.batch_size + while (microbatch_size / candidate.batch_size) * estimated_memory > max_memory: + microbatch_size //= 2 + if microbatch_size < num_chips: + raise ValueError( + f"Cannot fit model in {MAX_TPU_TYPE}: need microbatch >= {num_chips}, got {microbatch_size}" + ) + per_device_parallelism = microbatch_size // num_chips + + print( + f"Optimal config for {config.target_budget:.2e} FLOPs:\n" + f" hidden_dim={candidate.model_config.hidden_dim}, layers={candidate.model_config.num_layers}\n" + f" params={params:.2e}, tokens={candidate.tokens:.2e}\n" + f" batch_size={candidate.batch_size}, train_steps={candidate.train_steps}\n" + f" estimated_memory={estimated_memory / 1e9:.2f} GB -> {tpu_type}\n" + f" per_device_parallelism={per_device_parallelism or 'None (no grad accum)'}" ) - estimated_memory = MARIN_2025_RECIPE.estimate_memory_bytes(candidate, SEQ_LEN) - tpu_type = pick_v5p_type(estimated_memory) - logger.info(f"Estimated memory: {estimated_memory / 1e9:.2f} GB, TPU type: {tpu_type}") + # For very large models, use aggressive gradient checkpointing to reduce memory + # Following exp1295_32b.py pattern: offload only carries, not inputs + model_config = candidate.model_config + if config.target_budget >= 1e21: + from haliax import ScanCheckpointPolicy + + model_config = replace(model_config, gradient_checkpointing=ScanCheckpointPolicy(save_carries="offload")) + logger.info("Using offload carries gradient checkpointing for large model") # Build TrainLmConfig directly (like old run_scaling_ladder_rung) # config.tokenized is already processed by executor's instantiate_config @@ -168,6 +201,7 @@ def run_optimal_training(config: OptimalTrainingConfig) -> None: ), mp=jmp.get_policy("p=f32,c=bfloat16"), train_batch_size=candidate.batch_size, + per_device_parallelism=per_device_parallelism if per_device_parallelism else -1, num_train_steps=candidate.train_steps, steps_per_eval=1000, checkpointer=CheckpointerConfig( @@ -183,7 +217,7 @@ def run_optimal_training(config: OptimalTrainingConfig) -> None: allow_nondivisible_batch_size=True, ), train_seq_len=SEQ_LEN, - model=candidate.model_config, + model=model_config, optimizer=candidate.optimizer_config, ) diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index 38d70db170..961b051e65 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -322,21 
+322,45 @@ def _build_model_config_from_hidden_size(self, hidden_size: int, seq_len: int = def estimate_memory_bytes( self, candidate: CandidateConfig, - seq_len: int = DEFAULT_SEQ_LEN, optim_mult: int = 3, dtype_size: int = 4, fudge_factor: float = 2.0, ) -> int: - """Estimate float32 memory usage in bytes for training.""" + """Estimate memory usage in bytes for training. + + Accounts for: + - Parameters + optimizer state (master weights, momentum, variance) + - Activation memory including attention O(seq²) term + - Embedding table memory + """ model_config = candidate.model_config batch_size = candidate.batch_size - + seq_len = model_config.max_seq_len + hidden = model_config.hidden_dim + intermediate = getattr(model_config, "intermediate_dim", hidden * self.mlp_ratio) + layers = model_config.num_layers + # Parameters + optimizer (master weights + momentum + variance in fp32) param_count = model_config.total_trainable_params(self.vocab_size) param_bytes = param_count * optim_mult * dtype_size - act_bytes = (batch_size * model_config.max_seq_len) * ( - (model_config.hidden_dim * model_config.num_layers) + self.vocab_size * fudge_factor - ) - total_bytes = param_bytes + act_bytes + + # Activation memory per layer (bf16 = 2 bytes) + # - Hidden states: batch * seq * hidden + # - Attention Q/K/V/O: batch * seq * hidden * 4 (flash attention is O(seq), not O(seq²)) + # - MLP intermediate: batch * seq * intermediate + hidden_act = batch_size * seq_len * hidden * 2 + attn_act = batch_size * seq_len * hidden * 4 * 2 # Q, K, V, output tensors (flash attn) + mlp_act = batch_size * seq_len * intermediate * 2 + per_layer_act = hidden_act + attn_act + mlp_act + + # Activation memory scales with layers. Even with gradient checkpointing, + # we need significant memory for recomputation and gradient storage. + # Empirically, memory scales roughly as layers * 0.75 for large models. + act_bytes = per_layer_act * max(layers * 3 // 4, 4) + + # Embedding table (often not sharded well) + embed_bytes = self.vocab_size * hidden * 2 + + total_bytes = param_bytes + act_bytes + embed_bytes return int(total_bytes * fudge_factor) def build_model_configs( @@ -590,7 +614,7 @@ def create_isoflop_sweep_steps( # Create ExecutorSteps for each candidate configuration for candidate in candidates: model_config = candidate.model_config - estimated_memory = recipe.estimate_memory_bytes(candidate, seq_len) + estimated_memory = recipe.estimate_memory_bytes(candidate) tpu_type = pick_v5p_type(estimated_memory) # Use local naming with architecture details for backward compatibility diff --git a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py index 2b25734ccc..a5d57cc50e 100644 --- a/lib/marin/src/marin/scaling_laws/isoflop_analysis.py +++ b/lib/marin/src/marin/scaling_laws/isoflop_analysis.py @@ -183,7 +183,7 @@ class ScalingRecipe(Protocol): vocab_size: int """Vocabulary size for the tokenizer used with this recipe.""" - def estimate_memory_bytes(self, candidate: CandidateConfig, seq_len: int = DEFAULT_SEQ_LEN) -> int: + def estimate_memory_bytes(self, candidate: CandidateConfig) -> int: """Estimate memory usage in bytes for training a candidate configuration.""" ... 
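The revised `estimate_memory_bytes` above feeds directly into `pick_v5p_type` from PATCH 74. The sketch below shows how a byte estimate maps to a slice name, assuming 95 GiB of HBM per chip; the module defines its own `HBM_PER_CHIP_GIB`, so that figure and the helper name `pick_slice` are illustrative only.

```python
# Illustrative version of the slice-selection step (not the library function):
# round the byte estimate up to whole chips, convert chips to cores, and take
# the smallest v5p configuration with at least that many cores.
import math

HBM_GIB = 95                                   # assumed per-chip HBM for this sketch
CORES_PER_CHIP = 2
CORE_OPTIONS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]


def pick_slice(estimated_bytes: float) -> str:
    chips = math.ceil(estimated_bytes / (HBM_GIB * 1024**3))
    cores_needed = chips * CORES_PER_CHIP
    valid = [c for c in CORE_OPTIONS if c >= cores_needed]
    if not valid:
        raise ValueError(f"no v5p slice fits an estimate of {estimated_bytes / 1e9:.1f} GB")
    return f"v5p-{valid[0]}"


# Example: ~1.2 TB -> ceil(1.2e12 / (95 * 2**30)) = 12 chips -> 24 cores -> "v5p-32".
```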
diff --git a/lib/marin/src/marin/scaling_laws/tpu_utils.py b/lib/marin/src/marin/scaling_laws/tpu_utils.py
index 1a20cbf7c1..aae9459aec 100644
--- a/lib/marin/src/marin/scaling_laws/tpu_utils.py
+++ b/lib/marin/src/marin/scaling_laws/tpu_utils.py
@@ -29,7 +29,7 @@
 CORES_PER_CHIP = 2
 """Number of cores per TPU v5p chip."""
 
-V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512, 1024, 2048]
+V5P_CORE_OPTIONS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 """Available TPU v5p core configurations (slice sizes)."""
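The `MAX_TPU_TYPE` cap introduced in PATCH 79 falls back to gradient accumulation when the memory estimate exceeds the largest allowed slice. Below is a minimal sketch of that fallback, assuming memory scales linearly with the fraction of the batch held at once and an assumed 95 GiB per chip (the library defines its own `HBM_PER_CHIP_GIB`); the function name `grad_accum_plan` is illustrative, not the experiment's API.

```python
# Illustrative sketch of the gradient-accumulation fallback in run_optimal_training:
# cap the slice at MAX_TPU_TYPE and halve the microbatch until the estimated
# footprint fits, then convert the microbatch into per-device parallelism.
import math

HBM_GIB = 95                 # assumed per-chip HBM; the library defines HBM_PER_CHIP_GIB
CORES_PER_CHIP = 2
MAX_TPU_TYPE = "v5p-64"


def grad_accum_plan(batch_size: int, estimated_bytes: float) -> tuple[str, int | None]:
    max_cores = int(MAX_TPU_TYPE.split("-")[1])
    num_chips = max_cores // CORES_PER_CHIP
    max_bytes = num_chips * HBM_GIB * 1024**3

    if estimated_bytes <= max_bytes:
        # Fits as-is; the experiment picks the smallest adequate slice in this branch.
        return MAX_TPU_TYPE, None

    microbatch = batch_size
    # Memory is assumed to scale linearly with the microbatch fraction of the full batch.
    while (microbatch / batch_size) * estimated_bytes > max_bytes and microbatch >= num_chips:
        microbatch //= 2
    if microbatch < num_chips:
        raise ValueError(f"cannot fit even one example per chip on {MAX_TPU_TYPE}")
    return MAX_TPU_TYPE, microbatch // num_chips   # per_device_parallelism

# Example: a 64-way batch whose full-batch estimate is ~6 TB on a v5p-64 (32 chips)
# shrinks to a 32-example microbatch, i.e. per_device_parallelism = 1.
```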