diff --git a/.gitignore b/.gitignore index 2bd422eca7..5f4a0296dc 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ .git CLAUDE.md +GEMINI.md logs/ diff --git a/experiments/isoflop_sweep.py b/experiments/isoflop_sweep.py index ca68974aae..e7ddb1811c 100644 --- a/experiments/isoflop_sweep.py +++ b/experiments/isoflop_sweep.py @@ -96,12 +96,7 @@ def pick_v5p_type( seq_len: int, vocab: int, ) -> str: - """ - Select the smallest TPU v5p slice that fits the model in float32. - - Returns: - - TPU slice name, e.g., "v5p-8" or "v5p-32" - """ + """Select the smallest TPU v5p slice that fits the model in float32.""" param_count = compute_num_parameters(config, vocab) need_bytes = estimate_bytes(param_count, hidden, layers, batch, seq_len, vocab) chip_bytes = HBM_PER_CHIP_GIB * 1024**3 diff --git a/experiments/plantcad/exp2101_plantcad_isoflop_analysis.py b/experiments/plantcad/exp2101_plantcad_isoflop_analysis.py new file mode 100644 index 0000000000..0ded599f84 --- /dev/null +++ b/experiments/plantcad/exp2101_plantcad_isoflop_analysis.py @@ -0,0 +1,1101 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from rich.console import Console +from rich.table import Table +from scipy.interpolate import griddata +import wandb + +RUN_VERSION = "2.13" +RUN_PREFIX = f"plantcad_isoflop_v{RUN_VERSION}" +RESULT_PATH = f"experiments/plantcad/results/v{RUN_VERSION}" +EXPORT_DPI = 300 +DEFAULT_ARCH = "qwen" + +# When True, use non-embedding params (from params_nonembed tag) instead of total params +NON_EMBED_PARAMS_ONLY = False + +console = Console(record=True) +logger = logging.getLogger(__name__) + + +def setup_logging(log_path: Path) -> None: + """Configure logging to both console and file.""" + log_path.parent.mkdir(parents=True, exist_ok=True) + + # Clear any existing handlers + logger.handlers.clear() + logger.setLevel(logging.INFO) + + # Console handler + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(logging.Formatter("%(message)s")) + + # File handler + file_handler = logging.FileHandler(log_path, mode="w") + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")) + + logger.addHandler(console_handler) + logger.addHandler(file_handler) + + +def filter_to_finished_runs(df: pd.DataFrame, allow_crashed: bool = False) -> pd.DataFrame: + """Filter dataframe to include finished runs and nearly-complete crashed runs. + + Includes runs where: + - state == "finished", OR + - state == "crashed" AND run_progress > 0.999 + """ + is_finished = df["state"] == "finished" + if allow_crashed: + is_nearly_complete_crash = (df["state"] == "crashed") & (df["run_progress"] > 0.999) + return df[is_finished | is_nearly_complete_crash] + else: + return df[is_finished] + + +EXPLODED_RUNS: dict[str, list[str]] = {} +EXPLODED_BUDGETS: dict[str, list[float]] = { + "1.9": [1.0e16], + "1.10": [3.3e16], + "1.12": [3.3e16, 5.2e16], + "2.2": [8.0e16], + "2.4": [2.0e16], + "2.5": [4.0e16, 7.5e16], + "2.6": [1.0e16, 1.9e16], + "2.7": [3.2e17], + "2.9": [6.4e17, 1.2e18], + "2.12": [1.6e17, 3e17], + "2.13": [3.2e17], +} + + +def filter_exploded_runs(df: pd.DataFrame) -> pd.DataFrame: + """Filter out runs where training exploded (by run name or budget).""" + exploded_runs = EXPLODED_RUNS.get(RUN_VERSION, []) + exploded_budgets = EXPLODED_BUDGETS.get(RUN_VERSION, []) + + # Filter by run name + run_mask = df["run_name"].isin(exploded_runs) + for run_name in df.loc[run_mask, "run_name"]: + logger.warning(f"Filtering exploded run: {run_name}") + + # Filter by budget + budget_mask = df["flops_budget"].isin(exploded_budgets) + n_budget_filtered = budget_mask.sum() + if n_budget_filtered > 0: + filtered_budgets = df.loc[budget_mask, "flops_budget"].unique() + for budget in filtered_budgets: + budget_runs = df.loc[df["flops_budget"] == budget, "run_name"].tolist() + logger.warning(f"Filtering {len(budget_runs)} runs at exploded budget {budget:.1e}: {budget_runs}") + + return df[~run_mask & ~budget_mask] + + +def save_figure(fig, output_path: str) -> None: + """Save figure as both PNG and PDF at EXPORT_DPI resolution.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Save PNG + fig.savefig(output_path, dpi=EXPORT_DPI, bbox_inches="tight") + logger.info(f"Saved plot to {output_path}") + + # Save PDF + pdf_path = output_path.with_suffix(".pdf") + fig.savefig(pdf_path, dpi=EXPORT_DPI, bbox_inches="tight") + logger.info(f"Saved plot to {pdf_path}") + + +def log_run_object(run, run_idx): + """Log a run object as JSON to show available data.""" + logger.info(f"\n{'=' * 80}") + logger.info(f"RUN {run_idx + 1}: {run.name}") + logger.info(f"{'=' * 80}") + run_dict = { + "id": run.id, + "name": run.name, + "state": run.state, + "created_at": str(run.created_at), + "tags": run.tags, + "config": dict(run.config), + "summary": dict(run.summary), + } + logger.info(json.dumps(run_dict, indent=2, default=str)) + logger.info(f"{'=' * 80}\n") + + +def fetch_plantcad_runs(show_wandb_runs: bool = False): + """Fetch plantcad isoflop runs and extract metrics/tags into a dataframe.""" + api = wandb.Api(timeout=30) + # Note: Results from the first run (plantcad_isoflop_01) are available at: + # https://github.com/marin-community/marin/issues/2101#issuecomment-3581675724 + runs = api.runs( + "marin", + filters={"display_name": {"$regex": f"^{RUN_PREFIX}"}}, + ) + + data = [] + for idx, run in enumerate(runs): + # Log first 2 runs in detail + if show_wandb_runs and idx < 2: + log_run_object(run, idx) + + # Parse tags like "batch_size=32" + tags_dict = {} + for tag in run.tags: + if "=" in tag: + key, value = tag.split("=", 1) + try: + # Try to convert to appropriate type + if "." in value or "e+" in value or "e-" in value: + tags_dict[key] = float(value) + else: + tags_dict[key] = int(value) + except ValueError: + tags_dict[key] = value + + # Calculate execution time + start_time = pd.to_datetime(run.created_at) if run.created_at else None + stop_time = pd.to_datetime(run.summary.get("_timestamp"), unit="s") if run.summary.get("_timestamp") else None + + # Handle timezone differences + if start_time and stop_time: + if start_time.tzinfo and not stop_time.tzinfo: + stop_time = stop_time.tz_localize("UTC") + elif stop_time.tzinfo and not start_time.tzinfo: + start_time = start_time.tz_localize("UTC") + duration = (stop_time - start_time).total_seconds() + else: + duration = None + + if "eval/plantcad2/loss" in run.summary: + eval_metric = "eval/plantcad2/loss" + elif "eval/dclm_baseline/loss" in run.summary: + eval_metric = "eval/dclm_baseline/loss" + else: + logger.warning(f"No eval metric found in run {run.name}") + eval_metric = None + + # Determine which param count to use based on flag + if NON_EMBED_PARAMS_ONLY: + params_value = tags_dict.get("params_nonembed") + else: + params_value = tags_dict.get("params") + + row = { + "run_name": run.name, + "state": run.state, + "start_time": start_time, + "stop_time": stop_time, + "duration_sec": duration, + # Metrics + "eval_loss": run.summary.get(eval_metric) if eval_metric else None, + "train_loss": run.summary.get("train/loss"), + "total_gflops": run.summary.get("throughput/total_gflops"), + "total_tokens": run.summary.get("throughput/total_tokens"), + "run_progress": run.summary.get("run_progress"), + # Tags + "architecture": tags_dict.get("architecture"), + "batch_size": tags_dict.get("batch_size"), + "flops_budget": tags_dict.get("flops_budget"), + "hidden_size": tags_dict.get("hidden_size"), + "num_layers": tags_dict.get("num_layers"), + "params": params_value, + "steps": tags_dict.get("steps"), + "tokens": tags_dict.get("tokens"), + "tpu": tags_dict.get("tpu"), + "epochs": tags_dict.get("epochs"), + # Config + "hf_save_path": run.config.get("hf_save_path"), + } + data.append(row) + + return pd.DataFrame(data) + + +def save_runs(df, output_path=f"{RESULT_PATH}/plantcad_isoflops.csv"): + """Save dataframe to CSV file.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(output_path, index=False) + logger.info(f"Saved {len(df)} runs to {output_path}") + + +def validate_runs(df): + """Validate that rows are unique by key columns.""" + key_cols = ["architecture", "flops_budget", "tokens", "params", "epochs"] + duplicates = df[df.duplicated(subset=key_cols, keep=False)] + if not duplicates.empty: + logger.warning(f"Found {len(duplicates)} duplicate rows by {key_cols}:") + logger.warning(duplicates[["run_name", *key_cols]].to_string()) + else: + logger.info(f"Validation passed: rows are unique by {key_cols}") + + # Check that total_tokens matches tokens (within 0.1% tolerance) for finished runs only + df_finished = filter_to_finished_runs(df) + tolerance = 0.001 * df_finished["tokens"] + mismatch_mask = abs(df_finished["total_tokens"] - df_finished["tokens"]) > tolerance + if mismatch_mask.any(): + mismatches = df_finished.loc[mismatch_mask, ["run_name", "tokens", "total_tokens"]].copy() + mismatches["diff"] = mismatches["total_tokens"] - mismatches["tokens"] + mismatches["pct_diff"] = (mismatches["diff"] / mismatches["tokens"] * 100).round(2) + raise AssertionError(f"total_tokens != tokens for {mismatch_mask.sum()} runs:\n{mismatches.to_string()}") + + +def summarize_runs(df): + """Print formatted summary tables using rich.""" + gflops_to_flops = 1e9 + + # Run summary table + run_summary_cols = [ + "run_name", + "state", + "flops_budget", + "architecture", + "params", + "tokens", + "epochs", + "eval_loss", + "run_progress", + ] + summary_table = Table(title="Run Summary", show_header=True, header_style="bold cyan") + for col in run_summary_cols: + summary_table.add_column(col) + summary = df[run_summary_cols].copy() + for _, row in summary.sort_values(["flops_budget", "architecture", "epochs"]).iterrows(): + summary_table.add_row(*[str(v) if pd.notna(v) else "" for v in row]) + console.print(summary_table) + + # Checkpoint summary table - best runs per (flops_budget, architecture, epochs) + ckpt_cols = ["run_name", "flops_budget", "architecture", "epochs", "eval_loss", "hf_save_path"] + group_cols = ["flops_budget", "architecture", "epochs"] + # Find min eval_loss per group and keep all rows matching that min + df_with_min = df.merge( + df.groupby(group_cols)["eval_loss"].min().reset_index().rename(columns={"eval_loss": "min_eval_loss"}), + on=group_cols, + ) + best_runs = df_with_min[df_with_min["eval_loss"] == df_with_min["min_eval_loss"]][ckpt_cols].copy() + ckpt_table = Table( + title="Checkpoint Summary (Best per Budget/Arch/Epochs)", show_header=True, header_style="bold cyan" + ) + for col in ckpt_cols: + ckpt_table.add_column(col) + for _, row in best_runs.sort_values(group_cols).iterrows(): + ckpt_table.add_row(*[str(v) if pd.notna(v) else "" for v in row]) + console.print(ckpt_table) + + # FLOPs summary table + flops_table = Table(title="FLOPs Summary", show_header=True, header_style="bold cyan") + flops_table.add_column("Compute Budget", style="bold") + flops_table.add_column("Runs", justify="right") + flops_table.add_column("Budget (FLOPs)", justify="right") + flops_table.add_column("Actual (FLOPs)", justify="right") + + for budget, grp in df.groupby("flops_budget", sort=True): + flops_table.add_row( + f"{budget:.1e}", + str(len(grp)), + f"{grp['flops_budget'].sum():.3e}", + f"{grp['total_gflops'].sum() * gflops_to_flops:.3e}", + ) + flops_table.add_section() + flops_table.add_row( + "[bold]Total[/bold]", + f"[bold]{len(df)}[/bold]", + f"[bold]{df['flops_budget'].sum():.3e}[/bold]", + f"[bold]{df['total_gflops'].sum() * gflops_to_flops:.3e}[/bold]", + ) + console.print(flops_table) + + +def visualize_loss_by_token_count(df, metric="eval_loss", output_path=f"{RESULT_PATH}/plantcad_loss_by_tokens.png"): + """Plot loss vs tokens, colored by budget, faceted by architecture (cols) and epochs (rows).""" + required_cols = [metric, "tokens", "architecture", "flops_budget", "epochs"] + df_clean = filter_to_finished_runs(df).dropna(subset=required_cols) + + if df_clean.empty: + logger.warning(f"No finished runs with required columns {required_cols}. Skipping visualization.") + return + + architectures = sorted(df_clean["architecture"].unique()) + budgets = sorted(df_clean["flops_budget"].unique()) + unique_epochs = sorted(df_clean["epochs"].unique()) + + # Create budget colormap + cmap = plt.get_cmap("viridis") + budget_colors = {b: cmap(i / max(1, len(budgets) - 1)) for i, b in enumerate(budgets)} + + # Get global x-limits and per-epoch y-limits + x_min, x_max = df_clean["tokens"].min(), df_clean["tokens"].max() + epoch_ylims = {} + for epoch in unique_epochs: + epoch_data = df_clean[df_clean["epochs"] == epoch] + y_min, y_max = epoch_data[metric].min(), epoch_data[metric].max() + y_padding = (y_max - y_min) * 0.1 + epoch_ylims[epoch] = (y_min - y_padding, y_max + y_padding) + + fig, axes = plt.subplots( + len(unique_epochs), len(architectures), figsize=(5 * len(architectures), 3 * len(unique_epochs)), squeeze=False + ) + + for ei, epoch in enumerate(unique_epochs): + for ai, arch in enumerate(architectures): + ax = axes[ei, ai] + for budget in budgets: + data = df_clean[ + (df_clean["architecture"] == arch) + & (df_clean["flops_budget"] == budget) + & (df_clean["epochs"] == epoch) + ].sort_values("tokens") + if data.empty: + continue + color = budget_colors[budget] + ax.plot(data["tokens"], data[metric], alpha=0.7, linewidth=1.5, color=color) + ax.scatter(data["tokens"], data[metric], alpha=0.8, color=color, s=30) + ax.set_xlabel("Token Count") + ax.set_ylabel("Validation Loss") + ax.set_title(f"{arch} | {int(epoch)} Ep") + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlim(x_min, x_max) + ax.set_ylim(epoch_ylims[epoch]) + ax.grid(alpha=0.3) + + # Create legend for budget colors + handles = [ + plt.Line2D([0], [0], color=budget_colors[b], marker="o", linestyle="-", label=f"{b:.1e}") for b in budgets + ] + fig.legend(handles, [f"{b:.1e}" for b in budgets], title="Budget", loc="center left", bbox_to_anchor=(1, 0.5)) + + plt.tight_layout() + save_figure(fig, output_path) + + +def visualize_loss_by_param_count(df, metric="eval_loss", output_path=f"{RESULT_PATH}/plantcad_loss_by_params.png"): + """Plot loss vs params, colored by budget, faceted by architecture (cols) and epochs (rows).""" + required_cols = [metric, "params", "architecture", "flops_budget", "epochs"] + df_clean = filter_to_finished_runs(df).dropna(subset=required_cols) + + if df_clean.empty: + logger.warning(f"No finished runs with required columns {required_cols}. Skipping visualization.") + return + + architectures = sorted(df_clean["architecture"].unique()) + budgets = sorted(df_clean["flops_budget"].unique()) + unique_epochs = sorted(df_clean["epochs"].unique()) + + # Create budget colormap + cmap = plt.get_cmap("viridis") + budget_colors = {b: cmap(i / max(1, len(budgets) - 1)) for i, b in enumerate(budgets)} + + # Get global x-limits and per-epoch y-limits + x_min, x_max = df_clean["params"].min(), df_clean["params"].max() + epoch_ylims = {} + for epoch in unique_epochs: + epoch_data = df_clean[df_clean["epochs"] == epoch] + y_min, y_max = epoch_data[metric].min(), epoch_data[metric].max() + y_padding = (y_max - y_min) * 0.1 + epoch_ylims[epoch] = (y_min - y_padding, y_max + y_padding) + + fig, axes = plt.subplots( + len(unique_epochs), len(architectures), figsize=(5 * len(architectures), 3 * len(unique_epochs)), squeeze=False + ) + + for ei, epoch in enumerate(unique_epochs): + for ai, arch in enumerate(architectures): + ax = axes[ei, ai] + for budget in budgets: + data = df_clean[ + (df_clean["architecture"] == arch) + & (df_clean["flops_budget"] == budget) + & (df_clean["epochs"] == epoch) + ].sort_values("params") + if data.empty: + continue + color = budget_colors[budget] + ax.plot(data["params"], data[metric], alpha=0.7, linewidth=1.5, color=color) + ax.scatter(data["params"], data[metric], alpha=0.8, color=color, s=30) + ax.set_xlabel("Param Count") + ax.set_ylabel("Validation Loss") + ax.set_title(f"{arch} | {int(epoch)} Ep") + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlim(x_min, x_max) + ax.set_ylim(epoch_ylims[epoch]) + ax.grid(alpha=0.3) + + # Create legend for budget colors + handles = [ + plt.Line2D([0], [0], color=budget_colors[b], marker="o", linestyle="-", label=f"{b:.1e}") for b in budgets + ] + fig.legend(handles, [f"{b:.1e}" for b in budgets], title="Budget", loc="center left", bbox_to_anchor=(1, 0.5)) + + plt.tight_layout() + save_figure(fig, output_path) + + +def get_size_label(rank: int) -> str: + """Map model size rank to human-readable label.""" + labels = ["XXS", "XS", "S", "M", "L", "XL", "XXL", "XXXL"] + if rank < len(labels): + return labels[rank] + return f"Size-{rank}" + + +def visualize_loss_by_epochs( + df, + metric: str = "eval_loss", + output_path: str = f"{RESULT_PATH}/plantcad_loss_by_epochs.png", +) -> None: + """Plot normalized loss vs epochs, faceted by architecture and budget.""" + required_cols = [metric, "tokens", "params", "architecture", "flops_budget", "epochs"] + df_clean = filter_to_finished_runs(df).dropna(subset=required_cols).copy() + + architectures = sorted(df_clean["architecture"].unique()) + budgets = sorted(df_clean["flops_budget"].unique(), reverse=True) + if not architectures or not budgets: + logger.warning("No data to visualize.") + return + + # Normalize loss to 0-1 per (arch, budget, tokens, params) group + group_cols = ["architecture", "flops_budget", "tokens", "params"] + df_clean["loss_norm"] = df_clean.groupby(group_cols)[metric].transform( + lambda x: (x - x.min()) / (x.max() - x.min()) if x.max() > x.min() else 0.5 + ) + + # Create model size rank based on params (within each budget) + df_clean["size_rank"] = df_clean.groupby("flops_budget")["params"].transform( + lambda x: x.rank(method="dense").astype(int) - 1 + ) + unique_ranks = sorted(df_clean["size_rank"].unique()) + n_ranks = len(unique_ranks) + cmap = plt.get_cmap("tab10" if n_ranks <= 10 else "tab20") + rank_colors = {r: cmap(r / max(1, n_ranks - 1)) if n_ranks > 1 else cmap(0) for r in unique_ranks} + + fig, axes = plt.subplots( + len(budgets), len(architectures), figsize=(5 * len(architectures), 2 * len(budgets)), squeeze=False + ) + + for bi, budget in enumerate(budgets): + for ai, arch in enumerate(architectures): + ax = axes[bi, ai] + df_facet = df_clean[(df_clean["architecture"] == arch) & (df_clean["flops_budget"] == budget)] + if df_facet.empty: + ax.set_visible(False) + continue + + # Get unique (tokens, params, size_rank) combos for this facet + combos = ( + df_facet.groupby(["tokens", "params", "size_rank"]) + .size() + .reset_index()[["tokens", "params", "size_rank"]] + ) + combos = combos.sort_values(["size_rank", "tokens"]) + + for _, row in combos.iterrows(): + tokens, params, size_rank = row["tokens"], row["params"], row["size_rank"] + data = df_facet[(df_facet["tokens"] == tokens) & (df_facet["params"] == params)].sort_values("epochs") + if data.empty: + continue + color = rank_colors[size_rank] + ax.plot(data["epochs"], data["loss_norm"], color=color, alpha=0.7, linewidth=1.5) + ax.scatter(data["epochs"], data["loss_norm"], color=color, s=30, zorder=5) + + ax.set_xlabel("Epochs") + ax.set_ylabel("Normalized Loss (0-1)") + ax.set_title(f"{arch} | C={budget:.1e}") + ax.set_xscale("log", base=2) + ax.set_xlim(df_clean["epochs"].min(), df_clean["epochs"].max()) + ax.set_ylim(-0.05, 1.05) + ax.grid(alpha=0.3) + + # Create legend for model size ranks + handles = [plt.Line2D([0], [0], color=rank_colors[r], marker="o", linestyle="-") for r in unique_ranks] + labels = [get_size_label(r) for r in unique_ranks] + fig.legend(handles, labels, title="Model Size", loc="center left", bbox_to_anchor=(1, 0.5)) + + plt.tight_layout() + save_figure(fig, output_path) + + +def visualize_loss_by_param_and_epoch_count( + df, + architecture: str = DEFAULT_ARCH, + metric: str = "eval_loss", + clip_percentile: float = 80.0, + output_path: str = f"{RESULT_PATH}/plantcad_loss_contour.png", +) -> None: + """2D contour plot of loss vs params (y) and epochs (x), faceted by flops budget.""" + required_cols = [metric, "params", "epochs", "architecture", "flops_budget"] + df_clean = filter_to_finished_runs(df[df["architecture"] == architecture]).dropna(subset=required_cols).copy() + + if df_clean.empty: + logger.warning(f"No data for architecture '{architecture}'") + return + + n_unique_epochs = df_clean["epochs"].nunique() + if n_unique_epochs < 2: + logger.warning(f"Cannot create contour plot: need at least 2 unique epoch values, but got {n_unique_epochs}") + return + + df_clean["log_loss"] = np.log2(df_clean[metric]) + df_clean["log_loss"] = df_clean["log_loss"].clip(upper=df_clean["log_loss"].quantile(clip_percentile / 100)) + + budgets = sorted(df_clean["flops_budget"].unique()) + n_budgets = len(budgets) + + # Compute global color scale across all budgets + global_min = df_clean["log_loss"].min() + global_max = df_clean["log_loss"].max() + levels = np.linspace(global_min, global_max, 50) + + fig, axes = plt.subplots(1, n_budgets, figsize=(2.5 * n_budgets, 3.4), squeeze=False) + axes = axes[0] + + contour = None + + for idx, budget in enumerate(budgets): + ax = axes[idx] + df_budget = df_clean[df_clean["flops_budget"] == budget] + + # Interpolate scattered points to a finer grid in log space + x_data = np.log(df_budget["epochs"].values) + y_data = np.log(df_budget["params"].values) + z_data = df_budget["log_loss"].values + + xi = np.geomspace(df_budget["epochs"].min(), df_budget["epochs"].max(), 200) + yi = np.geomspace(df_budget["params"].min(), df_budget["params"].max(), 200) + xi_grid, yi_grid = np.meshgrid(xi, yi) + zi = griddata((x_data, y_data), z_data, (np.log(xi_grid), np.log(yi_grid)), method="cubic") + + contour = ax.contourf(xi_grid, yi_grid, zi, levels=levels, cmap="viridis", antialiased=True, extend="both") + + # Only show raw data points for the lowest compute budget + if idx == 0: + ax.scatter(df_budget["epochs"], df_budget["params"], c="black", s=22, alpha=0.15, edgecolors="none") + + ax.set_xscale("log", base=2) + ax.set_yscale("log") + ax.margins(0) + ax.set_ylabel("Params" if idx == 0 else "") + ax.set_title(f"C = {budget:.1e}") + if idx > 0: + ax.set_yticklabels([]) + + # Add colorbar in dedicated axes on the right, with top/bottom margin for labels + fig.subplots_adjust(right=0.92, top=0.82, bottom=0.15) + + # Single shared x-axis label + fig.supxlabel("Epochs") + fig.suptitle(f"Validation Loss ({architecture})", fontsize=14) + cbar_ax = fig.add_axes([0.94, 0.15, 0.02, 0.7]) + cbar = fig.colorbar(contour, cax=cbar_ax, label="log₂(Loss)") + cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x:.3f}")) + + save_figure(fig, output_path) + + +# ------------------------------------------------------------ +# Isoflop scaling law analysis (from exp2101_plantcad_isoflop_fit.py) +# ------------------------------------------------------------ + + +def fit_quadratic_optimum(x_vals, loss_vals): + """ + Fits L = a*(ln x)^2 + b*(ln x) + c and returns optimal x. + Returns (x_opt, fit_coeffs). + Raises ValueError if the fit is concave (no minimum). + """ + log_x = np.log(x_vals) + coeffs = np.polyfit(log_x, loss_vals, 2) # [a, b, c] + a, b, _ = coeffs + + # if a <= 0: + # raise ValueError(f"Concave fit detected: {coeffs}") + + # Minimum at ln x = -b / (2a) + ln_x_opt = -b / (2 * a) + x_opt = np.exp(ln_x_opt) + return x_opt, coeffs + + +def analyze_budgets(df): + """ + Analyzes each budget group to find optimal N and D using independent quadratic fits. + Returns a DataFrame with columns: budget, opt_N, opt_D, coeffs_N, coeffs_D, group_data. + """ + budgets = sorted(df["flops_budget"].unique()) + results = [] + + for budget in budgets: + group = df[df["flops_budget"] == budget].sort_values("params") + if len(group) < 3: + print(f"Skipping budget {budget}: has fewer than 3 points ({len(group)}), cannot fit.") + continue + + N = group["params"].values + D = group["tokens"].values + L = group["eval_loss"].values + + # Fit independent quadratics + opt_N, coeffs_N = fit_quadratic_optimum(N, L) + opt_D, coeffs_D = fit_quadratic_optimum(D, L) + + results.append( + { + "budget": budget, + "opt_N": opt_N, + "opt_D": opt_D, + "coeffs_N": coeffs_N, + "coeffs_D": coeffs_D, + "group_data": group, + } + ) + + return pd.DataFrame(results) + + +def fit_scaling_law(budgets, optimal_vals): + """ + Fits log(optimal_val) = m * log(budget) + c. + Returns (m, c, B_smooth, V_smooth) where B_smooth and V_smooth are smoothed + budget and predicted value arrays for plotting. + """ + log_B = np.log(budgets) + log_V = np.log(optimal_vals) + + m, c = np.polyfit(log_B, log_V, 1) + + # Generate smooth line + B_smooth = np.logspace(np.log10(budgets.min()), np.log10(budgets.max()), 100) + V_smooth = np.exp(m * np.log(B_smooth) + c) + + return m, c, B_smooth, V_smooth + + +def plot_isoflop_curves(ax_N, ax_D, analysis_results, colors_N, colors_D): + """Plots the top row: Loss vs Params/Tokens for each budget.""" + for idx, row in analysis_results.iterrows(): + budget = row["budget"] + + # Determine color based on index + # Assuming colors_N and colors_D are arrays matching analysis_results length + color_n = colors_N[idx] + color_d = colors_D[idx] + + group = row["group_data"] + + # Plot Data - N (Blues) + ax_N.scatter(group["params"], group["eval_loss"], color=color_n, alpha=0.7, s=20, label=f"{budget:.1e}") + + # Plot Data - D (Greens) + ax_D.scatter(group["tokens"], group["eval_loss"], color=color_d, alpha=0.7, s=20, label=f"{budget:.1e}") + + # Plot Fits (only between min and max x-values in the data) + N_range = np.logspace(np.log10(group["params"].min()), np.log10(group["params"].max()), 100) + D_range = np.logspace(np.log10(group["tokens"].min()), np.log10(group["tokens"].max()), 100) + + L_pred_N = np.polyval(row["coeffs_N"], np.log(N_range)) + L_pred_D = np.polyval(row["coeffs_D"], np.log(D_range)) + + # Lines use the same color scale + ax_N.plot(N_range, L_pred_N, color=color_n, linestyle="--", alpha=0.5) + ax_D.plot(D_range, L_pred_D, color=color_d, linestyle="--", alpha=0.5) + + # Plot Optima + L_min_N = np.polyval(row["coeffs_N"], np.log(row["opt_N"])) + ax_N.scatter([row["opt_N"]], [L_min_N], color=color_n, marker="s", s=100, edgecolors="black", zorder=10) + + L_min_D = np.polyval(row["coeffs_D"], np.log(row["opt_D"])) + ax_D.scatter([row["opt_D"]], [L_min_D], color=color_d, marker="s", s=100, edgecolors="black", zorder=10) + + # Configure axes + for ax, xlabel, title in [(ax_N, "Parameters (N)", "Loss vs Parameters"), (ax_D, "Tokens (D)", "Loss vs Tokens")]: + ax.set_xscale("log") + ax.set_xlabel(xlabel) + ax.set_ylabel("Validation Loss") + ax.set_title(title) + ax.grid(True, which="both", ls="-", alpha=0.2, axis="x") + # Don't show legend here anymore, global legend instead + + +def plot_scaling_laws(ax_N, ax_D, analysis_results, colors_N, colors_D): + """Plots the bottom row: Optimal Params/Tokens vs FLOPs.""" + valid_data = analysis_results.dropna(subset=["opt_N", "opt_D"]) + valid_indices = valid_data.index + + # Extract colors corresponding to valid data points + valid_colors_N = colors_N[valid_indices] + valid_colors_D = colors_D[valid_indices] + + budgets = valid_data["budget"].values + opt_N = valid_data["opt_N"].values + opt_D = valid_data["opt_D"].values + + # Fit Scaling Laws + m_N, c_N, B_smooth_N, N_smooth = fit_scaling_law(budgets, opt_N) + m_D, c_D, B_smooth_D, D_smooth = fit_scaling_law(budgets, opt_D) + + # Plot Params Scaling + ax_N.scatter(budgets, opt_N, color=valid_colors_N, marker="s", s=100, edgecolors="black", zorder=5) + ax_N.plot(B_smooth_N, N_smooth, color="gray", linestyle="--", label=f"$N^* \\propto C^{{{m_N:.3f}}}$") + + # Plot Tokens Scaling + ax_D.scatter(budgets, opt_D, color=valid_colors_D, marker="s", s=100, edgecolors="black", zorder=5) + ax_D.plot(B_smooth_D, D_smooth, color="gray", linestyle="--", label=f"$D^* \\propto C^{{{m_D:.3f}}}$") + + # Configure axes + for ax, ylabel, title in [ + (ax_N, "Optimal Parameters (N*)", "Optimal Parameters vs Compute"), + (ax_D, "Optimal Tokens (D*)", "Optimal Tokens vs Compute"), + ]: + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("Compute Budget (FLOPs)") + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.grid(True, which="both", ls="-", alpha=0.2, axis="x") + ax.legend() + # Set explicit tick positions at budget values only + ax.set_xticks(budgets) + ax.set_xticklabels([f"{b:.1e}" for b in budgets]) + ax.minorticks_off() + + # Calculate Ratio Function + ratio_coeff = np.exp(c_D - c_N) + ratio_exp_diff = m_D - m_N # Simple difference of exponents + # Optimal ratio exponent: R* = (m_N - m_D) / (m_N + m_D) + # This is the Chinchilla optimal token/param ratio exponent + ratio_exp_opt = (m_N - m_D) / (m_N + m_D) + + ratio_str = ( + f"Optimal Ratio: $\\frac{{D_{{opt}}}}{{N_{{opt}}}} = " + f"{ratio_exp_opt:.4f}$ ($\\frac{{D^*}}{{N^*}} = {ratio_exp_diff:.4f}$)" + ) + + # Return all computed values for summary display + return { + "ratio_str": ratio_str, + "m_N": m_N, + "c_N": c_N, + "m_D": m_D, + "c_D": c_D, + "ratio_coeff": ratio_coeff, + "ratio_exp_diff": ratio_exp_diff, + "ratio_exp_opt": ratio_exp_opt, + } + + +class DualColorMarker: + """ + Custom marker handler for legend that draws two markers side-by-side. + """ + + def __init__(self, color1, color2): + self.color1 = color1 + self.color2 = color2 + + def legend_artist(self, legend, orig_handle, fontsize, handlebox): + x0, y0 = handlebox.xdescent, handlebox.ydescent + width, height = handlebox.width, handlebox.height + + # Increase marker size (taller and wider) + rect_h = height * 1.4 # Significantly larger + rect_w = rect_h * 1.2 # Aspect ratio ~ 1.2 (wide) + + # Centered vertically + y_pos = y0 + (height - rect_h) / 2 + + # Gap between markers + marker_gap = 5 # Fixed pixel gap + + # Total width of the dual marker group + total_group_width = rect_w + marker_gap + rect_w + + # Align the dual marker group to the RIGHT side of the handlebox area + # This pushes it closer to the text label which starts immediately after handlebox + start_x = x0 + width - total_group_width + + # Draw two markers side-by-side + # Left marker (Params/Blue) + p1 = plt.Rectangle( + [start_x, y_pos], + rect_w, + rect_h, + facecolor=self.color1, + edgecolor="black", + transform=handlebox.get_transform(), + ) + # Right marker (Tokens/Green) + p2 = plt.Rectangle( + [start_x + rect_w + marker_gap, y_pos], + rect_w, + rect_h, + facecolor=self.color2, + edgecolor="black", + transform=handlebox.get_transform(), + ) + + handlebox.add_artist(p1) + handlebox.add_artist(p2) + return [p1, p2] + + +def plot_scaling_extrapolation(analysis_results, scaling_results, out_dir): + """Creates a figure showing N*, D*, and D*/N* vs compute with extrapolation.""" + valid_data = analysis_results.dropna(subset=["opt_N", "opt_D"]) + budgets = valid_data["budget"].values + opt_N, opt_D = valid_data["opt_N"].values, valid_data["opt_D"].values + + m_N, c_N = scaling_results["m_N"], scaling_results["c_N"] + m_D, c_D = scaling_results["m_D"], scaling_results["c_D"] + ratio_exp_diff = scaling_results["ratio_exp_diff"] # Simple: m_D - m_N + ratio_exp_opt = scaling_results["ratio_exp_opt"] # R*: (m_N - m_D) / (m_N + m_D) + + # Extrapolate from min observed to 1e22 FLOPs + C_ext = np.logspace(np.log10(budgets.min()), 22, 200) + N_ext = np.exp(c_N) * C_ext**m_N + D_ext = np.exp(c_D) * C_ext**m_D + # Use R* (optimal exponent) for the ratio line + # D*/N* = ratio_coeff * C^(-R*) since R* = (m_N - m_D)/(m_N + m_D) and we want D*/N* + ratio_coeff = scaling_results["ratio_coeff"] + ratio_ext = ratio_coeff * C_ext ** (-ratio_exp_opt) + + _, ax1 = plt.subplots(figsize=(10, 4.8)) + ax1.set_xscale("log") + ax1.set_yscale("log") + + # Plot fits and data + (ln,) = ax1.plot(C_ext, N_ext, color="tab:blue", lw=2, label=f"N* ∝ C^{m_N:.3f}") + (ld,) = ax1.plot(C_ext, D_ext, color="tab:green", lw=2, label=f"D* ∝ C^{m_D:.3f}") + ax1.scatter(budgets, opt_N, color="tab:blue", s=80, edgecolors="black", zorder=5, marker="o") + ax1.scatter(budgets, opt_D, color="tab:green", s=80, edgecolors="black", zorder=5, marker="s") + + # Extrapolation shading + ax1.axvspan(budgets.max(), 1e22, alpha=0.1, color="gray") + ax1.axvline(budgets.max(), color="gray", ls=":", alpha=0.5) + ax1.set_xlabel("Compute Budget C (FLOPs)", fontsize=12) + ax1.set_ylabel("Optimal N* (params) / D* (tokens)", fontsize=12) + ax1.grid(True, which="major", ls="-", alpha=0.2) + + # Right axis: Ratio (set behind ax1 so annotations render on top) + ax2 = ax1.twinx() + ax2.set_yscale("log") + ax2.set_zorder(ax1.get_zorder() - 1) + ax1.patch.set_visible(False) + # Label shows both simple and normalized exponents + (lr,) = ax2.plot( + C_ext, + ratio_ext, + color="gray", + lw=2, + ls="--", + label=f"D*/N* (R*={ratio_exp_opt:.3f}, Δm={ratio_exp_diff:.3f})", + ) + ax2.set_ylabel("Optimal Ratio D*/N*", fontsize=12) + + # Combined legend + ax1.legend(handles=[ln, ld, lr], loc="upper left") + + # Reference annotations + for C_ref in [1e18, 1e20, 1e22]: + N_ref, D_ref = np.exp(c_N) * C_ref**m_N, np.exp(c_D) * C_ref**m_D + ax1.axvline(C_ref, color="gray", ls=":", alpha=0.3) + ax1.annotate( + f"C={C_ref:.0e}\nN*={N_ref:.1e}\nD*={D_ref:.1e}\nD*/N*={D_ref/N_ref:.0f}", + xy=(C_ref, N_ref * 1.5), + fontsize=8, + ha="center", + va="bottom", + zorder=20, + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8, edgecolor="gray"), + ) + + plt.title("Scaling Law Extrapolation: Optimal Compute Allocation", fontsize=14) + plt.tight_layout() + + out_png, out_pdf = out_dir / "plantcad_scaling_extrapolation.png", out_dir / "plantcad_scaling_extrapolation.pdf" + plt.savefig(out_png, dpi=300, bbox_inches="tight") + plt.savefig(out_pdf, dpi=300, bbox_inches="tight") + plt.close() + return out_png, out_pdf + + +def print_summary(df, analysis_results, scaling_results, out_files): + """Print summary using rich tables.""" + console.print() + + # Optimal allocations table + opt_table = Table(title="Optimal Allocations") + opt_table.add_column("Budget (C)", justify="right", style="cyan") + opt_table.add_column("Opt N*", justify="right") + opt_table.add_column("Opt D*", justify="right") + opt_table.add_column("D*/N*", justify="right", style="yellow") + + for _, row in analysis_results.iterrows(): + ratio = row["opt_D"] / row["opt_N"] if row["opt_N"] > 0 else np.nan + opt_table.add_row( + f"{row['budget']:.1e}", + f"{row['opt_N']:.2e}", + f"{row['opt_D']:.2e}", + f"{ratio:.1f}", + ) + console.print(opt_table) + console.print() + + # Scaling laws + m_N, c_N = scaling_results["m_N"], scaling_results["c_N"] + m_D, c_D = scaling_results["m_D"], scaling_results["c_D"] + ratio_coeff = scaling_results["ratio_coeff"] + ratio_exp_diff = scaling_results["ratio_exp_diff"] + ratio_exp_opt = scaling_results["ratio_exp_opt"] + + console.print("[bold]Scaling Laws:[/bold]") + console.print(f" N* ∝ C^{m_N:.3f} [dim](log N* = {m_N:.4f} log C + {c_N:.4f})[/dim]") + console.print(f" D* ∝ C^{m_D:.3f} [dim](log D* = {m_D:.4f} log C + {c_D:.4f})[/dim]") + console.print(f" D*/N* = {ratio_coeff:.4e} · C^{ratio_exp_diff:.4f}") + console.print(f" R* = (m_N - m_D)/(m_N + m_D) = {ratio_exp_opt:.4f}") + console.print() + + for f in out_files: + console.print(f"[green]✓[/green] Saved: {f}") + console.print() + + +def run_isoflop_fit_analysis(df: pd.DataFrame, architecture: str = DEFAULT_ARCH) -> None: + """Run isoflop scaling law fitting and visualization on epoch=1 data for the given architecture.""" + df_fit = filter_to_finished_runs(df) + df_fit = df_fit[(df_fit["architecture"] == architecture) & (df_fit["epochs"] == 1)].dropna( + subset=["eval_loss", "tokens", "params", "flops_budget"] + ) + if df_fit.empty: + raise ValueError("No valid data found for isoflop fitting") + analysis_results = analyze_budgets(df_fit) + + # Setup Figure: 2x2 Grid + fig, axes = plt.subplots(2, 2, figsize=(12, 8), gridspec_kw={"height_ratios": [1.5, 1]}) + + # Colorscales + colors_N = plt.cm.Blues(np.linspace(0.4, 1.0, len(analysis_results))) + colors_D = plt.cm.Greens(np.linspace(0.4, 1.0, len(analysis_results))) + + # Top Row + plot_isoflop_curves(axes[0, 0], axes[0, 1], analysis_results, colors_N, colors_D) + + # Bottom Row + scaling_results = plot_scaling_laws(axes[1, 0], axes[1, 1], analysis_results, colors_N, colors_D) + + # Global Legend with dual-color markers + legend_handles = [] + for _, row in analysis_results.iterrows(): + legend_handles.append(plt.Rectangle((0, 0), 1, 1, color="none", label=f"{row['budget']:.1e}")) + + handler_map = {legend_handles[i]: DualColorMarker(colors_N[i], colors_D[i]) for i in range(len(legend_handles))} + + leg = fig.legend( + handles=legend_handles, + handler_map=handler_map, + title="Compute Budget [$C$]\n(FLOPs)", + loc="center left", + bbox_to_anchor=(0.86, 0.5), + borderaxespad=0.5, + handlelength=2.5, + labelspacing=0.6, + handletextpad=0.8, + ) + leg.get_title().set_multialignment("center") + + plt.suptitle("IsoFLOP Analysis: Quadratic Optima & Scaling Laws", y=0.94, fontsize=16) + plt.figtext(0.5, 0.88, scaling_results["ratio_str"], ha="center", fontsize=14) + plt.tight_layout(rect=[0, 0, 0.85, 0.91]) + + out_dir = Path(RESULT_PATH) + out_dir.mkdir(parents=True, exist_ok=True) + out_png = out_dir / "plantcad_isoflop_fits.png" + out_pdf = out_dir / "plantcad_isoflop_fits.pdf" + plt.savefig(out_png, dpi=300, bbox_inches="tight") + plt.savefig(out_pdf, dpi=300, bbox_inches="tight") + plt.close() + + # Create extrapolation figure + extrap_png, extrap_pdf = plot_scaling_extrapolation(analysis_results, scaling_results, out_dir) + + # Print summary + out_files = [out_png, out_pdf, extrap_png, extrap_pdf] + print_summary(df_fit, analysis_results, scaling_results, out_files) + + +# ------------------------------------------------------------ + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Fetch and analyze plantcad isoflop runs") + parser.add_argument( + "--force", + action="store_true", + help="Force refetch from W&B even if CSV exists", + ) + parser.add_argument( + "--output", + default=f"{RESULT_PATH}/plantcad_isoflops.csv", + help=f"Output CSV path (default: {RESULT_PATH}/plantcad_isoflops.csv)", + ) + parser.add_argument( + "--show-wandb-runs", + action="store_true", + help="Log detailed info for first 2 W&B runs", + ) + args = parser.parse_args() + + # Setup logging to console and file + log_path = Path(RESULT_PATH) / "plantcad_isoflop_analysis.txt" + setup_logging(log_path) + + output_path = Path(args.output) + + # Check if CSV exists and load from it unless --force is specified + if output_path.exists() and not args.force: + logger.info(f"Loading existing data from {output_path}") + df = pd.read_csv(output_path) + logger.info(f"Loaded {len(df)} runs from CSV") + else: + logger.info("Fetching runs from W&B...") + df = fetch_plantcad_runs(show_wandb_runs=args.show_wandb_runs) + save_runs(df, output_path) + + df = filter_exploded_runs(df) + validate_runs(df) + summarize_runs(df) + visualize_loss_by_token_count(df) + visualize_loss_by_param_count(df) + visualize_loss_by_epochs(df) + visualize_loss_by_param_and_epoch_count(df) + run_isoflop_fit_analysis(df) + + # Append rich console output to log file + with open(log_path, "a") as f: + f.write("\n" + console.export_text()) + f.write(f"\nAnalysis complete. Logs saved to {log_path}\n") + # Also print to console + console.print(f"[green]Analysis complete.[/green] Logs saved to {log_path}") diff --git a/experiments/plantcad/exp2101_plantcad_isotoken_analysis.py b/experiments/plantcad/exp2101_plantcad_isotoken_analysis.py new file mode 100644 index 0000000000..e0bc05c13e --- /dev/null +++ b/experiments/plantcad/exp2101_plantcad_isotoken_analysis.py @@ -0,0 +1,408 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Analyze isotoken sweep runs: plot eval_loss vs tokens, grouped by params, faceted by token_fraction.""" + +import argparse +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import wandb +from rich.console import Console +from rich.table import Table + +logger = logging.getLogger(__name__) + +RUN_PREFIX = "plantcad_width_sweep_v1.0" +RESULT_PATH = "experiments/plantcad/results/isotoken_sweep" + +# Plotly colors +PLOTLY_RED = "#EF553B" +PLOTLY_GREEN = "#00CC96" + + +def fit_changepoint_model(log_x: np.ndarray, y: np.ndarray) -> dict | None: + """Fit a continuous piecewise linear model with one change point in log-space. + + The model is continuous at the changepoint: + - For x <= cp: y = y_cp + slope1 * (x - cp) + - For x > cp: y = y_cp + slope2 * (x - cp) + + Returns dict with: changepoint, slope1, slope2, y_cp, rss + """ + n = len(log_x) + if n < 4: + # Not enough points for a meaningful changepoint model + return None + + best_rss = np.inf + best_result = None + + # Sort by log_x to ensure proper ordering + order = np.argsort(log_x) + log_x_sorted = log_x[order] + y_sorted = y[order] + + # Try each possible changepoint (need at least 2 points on each side) + for cp_idx in range(2, n - 1): + cp_log = log_x_sorted[cp_idx] + + # Build design matrix for continuous piecewise linear model: + # y = y_cp + slope1 * (x - cp) for x <= cp + # y = y_cp + slope2 * (x - cp) for x > cp + # Rewrite as: y = y_cp + slope1 * (x - cp) * I(x <= cp) + slope2 * (x - cp) * I(x > cp) + # Parameters: [y_cp, slope1, slope2] + + x_centered = log_x_sorted - cp_log + left_mask = log_x_sorted <= cp_log + right_mask = ~left_mask + + # Design matrix: [1, (x-cp)*I_left, (x-cp)*I_right] + X = np.column_stack( + [ + np.ones(n), + x_centered * left_mask, + x_centered * right_mask, + ] + ) + + # Least squares fit + coeffs, _residuals, _rank, _s = np.linalg.lstsq(X, y_sorted, rcond=None) + y_cp, slope1, slope2 = coeffs + + # Compute predictions and RSS + y_pred = X @ coeffs + rss = np.sum((y_sorted - y_pred) ** 2) + + if rss < best_rss: + best_rss = rss + best_result = { + "changepoint_log": cp_log, + "changepoint": 10**cp_log, + "slope1": slope1, + "slope2": slope2, + "y_cp": y_cp, + "rss": rss, + "cp_idx": cp_idx, + } + + return best_result + + +def fetch_runs() -> pd.DataFrame: + """Fetch isotoken sweep runs from W&B.""" + api = wandb.Api() + runs = api.runs("marin", filters={"display_name": {"$regex": f"^{RUN_PREFIX}"}}) + + data = [] + run_names = [] + for run in runs: + if run.state != "finished": + continue + tags = {k: v for t in run.tags if "=" in t for k, v in [t.split("=", 1)]} + cfg = run.config + run_names.append(run.name) + data.append( + { + "eval_loss": run.summary.get("eval/plantcad2/loss"), + "tokens": float(tags.get("tokens", 0)), + "params": float(tags.get("params", 0)), + "token_fraction": float(tags.get("token_fraction", 0)), + "hidden_size": int(tags.get("hidden_size", 0)), + "batch_size": cfg.get("trainer", {}).get("train_batch_size"), + "lr": cfg.get("optimizer", {}).get("learning_rate"), + "beta2": cfg.get("optimizer", {}).get("beta2"), + "num_layers": cfg.get("model", {}).get("num_layers"), + "num_heads": cfg.get("model", {}).get("num_heads"), + } + ) + + df = pd.DataFrame(data, index=run_names) + + # Log rows with missing essential fields before dropping + essential = ["eval_loss", "tokens", "params", "token_fraction"] + missing_mask = df[essential].isna().any(axis=1) + for run_name in df[missing_mask].index: + row = df.loc[run_name] + missing = [col for col in essential if pd.isna(row[col])] + logger.warning(f"Dropping run '{run_name}': missing {missing}") + + return df.dropna(subset=essential).reset_index(drop=True) + + +def format_count(value: float, suffix: str = "") -> str: + """Format large numbers with M/B suffix.""" + if value >= 1e9: + return f"{value/1e9:.1f}B{suffix}" + return f"{value/1e6:.1f}M{suffix}" + + +def format_power_of_2_fraction(frac: float) -> str: + """Format a fraction as 1/2^n (e.g., 0.25 -> '1/4').""" + if frac == 1.0: + return "1" + denom = round(1 / frac) + return f"1/{denom}" + + +def build_fit_annotation(slope: float, rss: float, cp_result: dict | None, cp_label_fmt: callable) -> str: + """Build annotation text for log-linear and changepoint model fits.""" + lines = ["Log-linear model:", f" slope = {slope:.4f}", f" RSS = {rss:.2e}"] + if cp_result is not None: + lines.extend( + [ + "", + "Changepoint model:", + f" slope₁ = {cp_result['slope1']:.4f}", + f" slope₂ = {cp_result['slope2']:.4f}", + f" changepoint = {cp_label_fmt(cp_result['changepoint'])}", + f" RSS = {cp_result['rss']:.2e}", + ] + ) + return "\n".join(lines) + + +def _fmt_int(val) -> str: + return str(int(val)) if pd.notna(val) else "-" + + +def _fmt_float(val, fmt: str = ".4f") -> str: + return f"{val:{fmt}}" if pd.notna(val) else "-" + + +def print_run_table(df: pd.DataFrame) -> None: + """Print a rich table summarizing experiment configurations.""" + table = Table(title="Isotoken Sweep Experiments") + + table.add_column("Params", justify="right") + table.add_column("Data", justify="center") + table.add_column("Tokens", justify="right") + table.add_column("Hidden", justify="right") + table.add_column("Layers", justify="right") + table.add_column("Heads", justify="right") + table.add_column("Batch", justify="right") + table.add_column("LR", justify="right") + table.add_column("β₂", justify="right") + table.add_column("Loss", justify="right") + + # Sort by params then by token_fraction descending + sorted_df = df.sort_values(["params", "token_fraction"], ascending=[True, False]) + + for _, row in sorted_df.iterrows(): + table.add_row( + format_count(row["params"]), + format_power_of_2_fraction(row["token_fraction"]), + format_count(row["tokens"]), + _fmt_int(row["hidden_size"]), + _fmt_int(row.get("num_layers")), + _fmt_int(row.get("num_heads")), + _fmt_int(row.get("batch_size")), + _fmt_float(row.get("lr")), + _fmt_float(row.get("beta2")), + _fmt_float(row["eval_loss"]), + ) + + Console().print(table) + + +def plot_fits(ax, x: np.ndarray, y: np.ndarray, cp_label_fmt: callable) -> None: + """Plot log-linear and changepoint model fits with annotations.""" + log_x = np.log10(x) + + # Log-linear fit + slope, intercept = np.polyfit(log_x, y, 1) + rss_linear = np.sum((y - (slope * log_x + intercept)) ** 2) + + # Changepoint fit + cp_result = fit_changepoint_model(log_x, y) + + # Plot fit lines + x_fit = np.logspace(log_x.min(), log_x.max(), 100) + ax.plot(x_fit, slope * np.log10(x_fit) + intercept, color="gray", ls="--", lw=1.5, alpha=0.6, label="Log-linear") + + if cp_result is not None: + cp_log, y_cp = cp_result["changepoint_log"], cp_result["y_cp"] + for x_seg, mask in [(x_fit[np.log10(x_fit) <= cp_log], True), (x_fit[np.log10(x_fit) >= cp_log], False)]: + slope_seg = cp_result["slope1"] if mask else cp_result["slope2"] + label = "Changepoint" if mask else None + ax.plot( + x_seg, + y_cp + slope_seg * (np.log10(x_seg) - cp_log), + color=PLOTLY_GREEN, + ls="-", + lw=2, + alpha=0.9, + label=label, + ) + ax.axvline(cp_result["changepoint"], color=PLOTLY_GREEN, ls=":", alpha=0.5, lw=1) + + # Annotation + ax.text( + 0.98, + 0.98, + build_fit_annotation(slope, rss_linear, cp_result, cp_label_fmt), + transform=ax.transAxes, + fontsize=9, + fontfamily="monospace", + va="top", + ha="right", + bbox=dict(boxstyle="round,pad=0.4", facecolor="white", edgecolor="gray", alpha=0.9), + ) + + +def save_figure(fig, output_path: str) -> None: + """Save figure as PNG and PDF.""" + out = Path(output_path) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out, dpi=300, bbox_inches="tight") + fig.savefig(out.with_suffix(".pdf"), dpi=300, bbox_inches="tight") + print(f"Saved: {out} and {out.with_suffix('.pdf')}") + + +def plot_loss_by_params(df: pd.DataFrame, output_path: str, max_data_fraction: float = 0.5) -> None: + """Plot eval_loss vs params, faceted by token_fraction, with log-linear and changepoint fits.""" + df = df[df["token_fraction"] <= max_data_fraction] + + fractions = sorted(df["token_fraction"].unique()) + params_list = sorted(df["params"].unique()) + + # Map params to sizes (small range: 30-80) + sizes = np.linspace(30, 80, len(params_list)) + size_map = {p: s for p, s in zip(params_list, sizes, strict=True)} + + # Shared x-axis limits with padding in log space + xlim = (params_list[0] * 0.8, params_list[-1] * 1.2) + + n_cols = 3 + n_rows = (len(fractions) + n_cols - 1) // n_cols + fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5 * n_cols, 4 * n_rows), squeeze=False) + + for i, frac in enumerate(fractions): + ax = axes[divmod(i, n_cols)] + data = df[df["token_fraction"] == frac].sort_values("params") + x, y = data["params"].values, data["eval_loss"].values + + ax.plot(x, y, color="gray", alpha=0.4, lw=1) + for _, row in data.iterrows(): + ax.scatter(row["params"], row["eval_loss"], color="C0", s=size_map[row["params"]], zorder=5) + + plot_fits(ax, x, y, cp_label_fmt=lambda v: format_count(v)) + + total_tokens = data["tokens"].iloc[0] + ax.set_xlabel("Params") + ax.set_ylabel("Eval Loss" if i % n_cols == 0 else "") + ax.set_title(f"Data: {frac:.0%} = {format_power_of_2_fraction(frac)} ({format_count(total_tokens, ' tokens')})") + ax.set_xscale("log") + ax.set_xlim(xlim) + ax.grid(alpha=0.3) + ax.legend(loc="lower left", fontsize=8) + + for i in range(len(fractions), n_rows * n_cols): + axes[divmod(i, n_cols)].set_visible(False) + + # Size-based legend + handles = [ + plt.Line2D([0], [0], marker="o", color="w", markerfacecolor="C0", markersize=np.sqrt(size_map[p])) + for p in params_list + ] + fig.legend( + handles, [format_count(p) for p in params_list], title="Params", loc="center left", bbox_to_anchor=(1, 0.5) + ) + + plt.tight_layout() + save_figure(fig, output_path) + + +def plot_loss_by_tokens(df: pd.DataFrame, output_path: str, max_data_fraction: float = 0.5) -> None: + """Plot eval_loss vs tokens, faceted by params, with log-linear and changepoint fits.""" + df = df[df["token_fraction"] <= max_data_fraction] + params_list = sorted(df["params"].unique()) + fractions = sorted(df["token_fraction"].unique()) + + # Map fractions to sizes (small range: 30-80) + sizes = np.linspace(30, 80, len(fractions)) + size_map = {f: s for f, s in zip(fractions, sizes, strict=True)} + + n_cols = 3 + n_rows = (len(params_list) + n_cols - 1) // n_cols + fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5 * n_cols, 4 * n_rows), squeeze=False) + + for i, params in enumerate(params_list): + ax = axes[divmod(i, n_cols)] + data = df[df["params"] == params].sort_values("tokens") + x, y = data["tokens"].values, data["eval_loss"].values + + ax.plot(x, y, color="gray", alpha=0.4, lw=1) + for _, row in data.iterrows(): + ax.scatter(row["tokens"], row["eval_loss"], color="C0", s=size_map[row["token_fraction"]], zorder=5) + + plot_fits(ax, x, y, cp_label_fmt=lambda v: format_count(v, " tok")) + + ax.set_xlabel("Tokens") + ax.set_ylabel("Eval Loss" if i % n_cols == 0 else "") + ax.set_title(f"Params: {format_count(params)}") + ax.set_xscale("log") + ax.grid(alpha=0.3) + ax.legend(loc="lower left", fontsize=8) + + for i in range(len(params_list), n_rows * n_cols): + axes[divmod(i, n_cols)].set_visible(False) + + # Size-based legend + handles = [ + plt.Line2D([0], [0], marker="o", color="w", markerfacecolor="C0", markersize=np.sqrt(size_map[f])) + for f in fractions + ] + fig.legend( + handles, [f"{f:.0%}" for f in fractions], title="Data Fraction", loc="center left", bbox_to_anchor=(1, 0.5) + ) + + plt.tight_layout() + save_figure(fig, output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output-params", default=f"{RESULT_PATH}/isotoken_sweep_loss_by_params.png") + parser.add_argument("--output-tokens", default=f"{RESULT_PATH}/isotoken_sweep_loss_by_tokens.png") + parser.add_argument("--csv", default=f"{RESULT_PATH}/isotoken_sweep_runs.csv") + parser.add_argument("--force", action="store_true", help="Force refetch from W&B") + args = parser.parse_args() + + csv_path = Path(args.csv) + if csv_path.exists() and not args.force: + print(f"Loading from {csv_path}") + df = pd.read_csv(csv_path) + else: + print("Fetching from W&B...") + df = fetch_runs() + csv_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(csv_path, index=False) + print(f"Saved {len(df)} runs to {csv_path}") + + # Ensure numeric types for proper sorting + for col in ["tokens", "params", "token_fraction"]: + df[col] = pd.to_numeric(df[col]) + + print(f"Loaded {len(df)} runs") + + if df.empty: + print("No runs found. Check RUN_PREFIX or use --force to refetch.") + else: + print_run_table(df) + plot_loss_by_params(df, args.output_params) + plot_loss_by_tokens(df, args.output_tokens) diff --git a/experiments/plantcad/exp2101_plantcad_isotoken_sweep.py b/experiments/plantcad/exp2101_plantcad_isotoken_sweep.py new file mode 100644 index 0000000000..02d50593f3 --- /dev/null +++ b/experiments/plantcad/exp2101_plantcad_isotoken_sweep.py @@ -0,0 +1,199 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Iso-token sweep: vary hidden_size/width across fixed token fractions (1/16, 1/8, 1/4 of a dataset).""" + +import logging +import math +from dataclasses import dataclass + +from levanter.models.qwen import Qwen3Config +from levanter.optim.cautious import CautiousConfig + +from experiments.defaults import default_train, _prepare_data_config +from experiments.llama import compute_num_parameters +from experiments.simple_train_config import SimpleTrainConfig +from marin.execution.executor import ExecutorStep, executor_main +from fray.cluster import ResourceConfig + +from experiments.plantcad.exp2101_plantcad_isoflop_sweep import ( + prepare_plantcad_dataset, + tokenize_plantcad_dataset, + IsoFlopTokenizeConfig, + IsoFlopDataConfig, + pick_v5p_type, + format_num, +) + +logger = logging.getLogger("ray") + +# Sweep parameters +HIDDEN_SIZES = [256, 384, 512, 640, 768, 896, 1024, 1280, 1536, 1792, 2048] +TOKEN_FRACTIONS = [1.0 / 2] +TRAIN_STEPS = 8192 +MLP_RATIO = 4 +HIDDEN_HEAD_RATIO = 128 +LR_CONSTANT = 0.33 +LR_MAX = 0.02 + + +@dataclass(frozen=True) +class WidthSweepConfig: + tokenized_dataset: ExecutorStep + vocab_size: int + seq_len: int + dataset_tokens: int + experiment_name: str = "plantcad_width_sweep_v1.0" + + +def round_to_power_of_two(x: float) -> int: + return max(1, 2 ** round(math.log2(x))) + + +def generate_width_sweep_steps(cfg: WidthSweepConfig) -> list[ExecutorStep]: + """Generate training steps for width sweep.""" + steps: list[ExecutorStep] = [] + rows: list[dict] = [] + + for hidden_size in HIDDEN_SIZES: + for token_fraction in TOKEN_FRACTIONS: + target_tokens = int(cfg.dataset_tokens * token_fraction) + batch_size = round_to_power_of_two(target_tokens / (TRAIN_STEPS * cfg.seq_len)) + + # LR scales with sqrt(batch) / hidden + lr = min(LR_MAX, (LR_CONSTANT * math.sqrt(batch_size)) / hidden_size) + beta2 = 0.98 ** (batch_size / 128) + + # Derive architecture from hidden size + intermediate_dim = hidden_size * MLP_RATIO + n_heads = max(1, hidden_size // HIDDEN_HEAD_RATIO) + hs_pow = math.log2(hidden_size) + num_layers = max(1, round(hidden_size / (64 + (hs_pow * 4) - 8))) + + model_cfg = Qwen3Config( + max_seq_len=cfg.seq_len, + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_heads=n_heads, + num_kv_heads=n_heads, + num_layers=num_layers, + ) + + tpu_type = pick_v5p_type(model_cfg, hidden_size, num_layers, batch_size, cfg.seq_len, cfg.vocab_size) + num_params = compute_num_parameters(model_cfg, cfg.vocab_size) + actual_tokens = TRAIN_STEPS * batch_size * cfg.seq_len + + optimizer_cfg = CautiousConfig( + learning_rate=lr, + weight_decay=0.1, + min_lr_ratio=0.0, + warmup=0.1, + beta1=0.95, + beta2=beta2, + epsilon=1e-15, + max_grad_norm=1, + adamc_weight_decay=True, + lr_schedule="linear", + decay=0.2, + ) + + train_cfg = SimpleTrainConfig( + resources=ResourceConfig.with_tpu(tpu_type), + train_batch_size=batch_size, + num_train_steps=TRAIN_STEPS, + learning_rate=lr, + weight_decay=0.1, + min_lr_ratio=0.0, + lr_schedule="linear", + decay=0.2, + steps_per_eval=TRAIN_STEPS // 2, + per_device_eval_parallelism=512, + max_eval_batches=64, + optimizer_config=optimizer_cfg, + ) + + pretraining_data = _prepare_data_config(cfg.tokenized_dataset, use_default_validation=False) + + frac_label = f"{int(1/token_fraction)}x" + step = default_train( + name=f"{cfg.experiment_name}-H{hidden_size}-F{frac_label}-P{format_num(num_params)}", + tokenized=pretraining_data, + model_config=model_cfg, + train_config=train_cfg, + tags=( + f"hidden_size={hidden_size}", + f"token_fraction={token_fraction}", + f"params={num_params}", + f"tokens={actual_tokens}", + f"batch_size={batch_size}", + f"tpu={tpu_type}", + ), + use_default_validation=False, + eval_harness_tasks=[], + ) + steps.append(step) + + rows.append( + { + "hidden": hidden_size, + "layers": num_layers, + "heads": n_heads, + "frac": f"{token_fraction:.2f}", + "batch": batch_size, + "lr": f"{lr:.2e}", + "beta2": f"{beta2:.3f}", + "params": format_num(num_params), + "target_tok": format_num(target_tokens), + "actual_tok": format_num(actual_tokens), + "tpu": tpu_type, + } + ) + + # Print config table + if rows: + headers = list(rows[0].keys()) + col_widths = {h: max(len(h), max(len(str(r[h])) for r in rows)) for h in headers} + header_line = " | ".join(h.rjust(col_widths[h]) for h in headers) + sep_line = "-+-".join("-" * col_widths[h] for h in headers) + print(f"\n{header_line}\n{sep_line}") + for r in rows: + print(" | ".join(str(r[h]).rjust(col_widths[h]) for h in headers)) + print() + + logger.info(f"Generated {len(steps)} width sweep configurations") + return steps + + +def main(): + plantcad_prepared = prepare_plantcad_dataset() + plantcad_tokenized = tokenize_plantcad_dataset(prepared=plantcad_prepared) + + # Effective dataset tokens after cropping + data_cfg: IsoFlopDataConfig = plantcad_prepared.config + tok_cfg: IsoFlopTokenizeConfig = plantcad_tokenized.config + dataset_tokens = int(data_cfg.total_token_count * (data_cfg.output_seq_len / data_cfg.input_seq_len)) + + sweep_cfg = WidthSweepConfig( + tokenized_dataset=plantcad_tokenized, + vocab_size=tok_cfg.vocab_size, + seq_len=data_cfg.output_seq_len, + dataset_tokens=dataset_tokens, + ) + + sweep_steps = generate_width_sweep_steps(sweep_cfg) + executor_main(steps=[plantcad_prepared, plantcad_tokenized, *sweep_steps]) + + +if __name__ == "__main__": + main() diff --git a/experiments/plantcad/exp2101_plantcad_meta_analysis.py b/experiments/plantcad/exp2101_plantcad_meta_analysis.py new file mode 100644 index 0000000000..8083a760be --- /dev/null +++ b/experiments/plantcad/exp2101_plantcad_meta_analysis.py @@ -0,0 +1,819 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: RUF001 RUF002 RUF003 +# ↑ Allow Greek letters (α, β) and math symbols (×) in strings/comments for readability. + +""" +Meta-analysis of isoflop scaling experiments across multiple compute ranges. + +Aggregates results from individual isoflop sweeps and fits scaling laws to understand +how optimal model size and training tokens scale with compute budget. Supports two +fitting strategies: + +- **Parabolic**: Fits a log-quadratic curve to each isoflop slice independently, + finding the minimum via the parabola vertex. Simple but treats each budget in isolation. + +- **Parametric** (Chinchilla Approach 3): Fits the full loss surface + L(N, D) = E + A/N^α + B/D^β across all points for a dataset, then derives optimal + N* and D* from the constraint C = k·N·D. More principled but requires more data. + +Generates a 2x2 figure: +- Top row: Loss vs Parameters, Loss vs Tokens (raw data + fitted curves per budget) +- Bottom row: Optimal N* vs FLOPs, Optimal D* vs FLOPs (power-law scaling fits) + +Configure DATASET, FIT_STRATEGY, and filter ranges at the top of the file. +""" + +import logging +from collections.abc import Iterator +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from scipy.optimize import least_squares + +# Reuse existing validation and filtering functions +from experiments.plantcad.exp2101_plantcad_isoflop_analysis import ( + EXPLODED_BUDGETS, + EXPLODED_RUNS, + filter_to_finished_runs, + fit_scaling_law, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Define compute ranges with their metadata +# Format: (version, display_name, steps) +DATASETS = { + # PlantCAD compute ranges (v2.x): + # Note: v2.3 (minimal-compute, 2K steps) excluded from analysis + "plantcad": [ + ("v2.6", "very-low (4K steps)", 4096), + ("v2.4", "low (8K steps)", 8192), + ("v2.5", "mid (16K steps)", 16384), + ("v2.2", "high (32K steps)", 32768), + ], + # Text (DCLM baseline) compute ranges: + "dclm": [ + ("v2.12", "low (16K steps)", 16384), + ("v2.13", "mid (32K steps)", 32768), + ("v2.9", "high (64K steps)", 65536), + ], +} + +DATASET = "dclm" +COMPUTE_RANGES = DATASETS[DATASET] + +RESULTS_BASE_PATH = Path("experiments/plantcad/results") +OUTPUT_DIR = RESULTS_BASE_PATH / "meta_analysis" +OUTPUT_BASE = "plantcad_isoflop_meta_analysis" +DEFAULT_ARCH = "qwen" +EXPORT_DPI = 300 + +# Optional filters for parabola fitting (None = no filter) +# All data is plotted; these only control what's used for fitting +PARAMS_RANGE: tuple[float | None, float | None] = (None, None) # (min, max) +TOKENS_RANGE: tuple[float | None, float | None] = (None, None) # (min, max) +BATCH_SIZE_RANGE: tuple[int | None, int | None] = (None, None) # (min, max), e.g. (8, 64) + +# Fit strategy: "parabolic" or "parametric" +FIT_STRATEGY = "parametric" + +# Color scheme for compute ranges (distinct colors for each range) +RANGE_COLORS = { + # PlantCAD versions + "v2.3": "#1f77b4", # Blue - minimal + "v2.6": "#ff7f0e", # Orange - very-low + "v2.4": "#2ca02c", # Green - low + "v2.5": "#d62728", # Red - mid + "v2.2": "#9467bd", # Purple - high + # Text (DCLM baseline) versions + "v2.12": "#2ca02c", # Green - low + "v2.13": "#d62728", # Red - mid + "v2.9": "#9467bd", # Purple - high +} + + +def load_all_csvs() -> dict[str, pd.DataFrame]: + """Load all plantcad_isoflops.csv files for v2.x versions.""" + data: dict[str, pd.DataFrame] = {} + for version, name, steps in COMPUTE_RANGES: + csv_path = RESULTS_BASE_PATH / version / "plantcad_isoflops.csv" + if csv_path.exists(): + df = pd.read_csv(csv_path) + df["version"] = version + df["compute_range"] = name + df["range_steps"] = steps + data[version] = df + logger.info(f"Loaded {len(df)} runs from {csv_path}") + else: + logger.warning(f"CSV not found: {csv_path}") + return data + + +def filter_exploded_runs_for_version(df: pd.DataFrame, version: str) -> pd.DataFrame: + """Filter out runs where training exploded for a specific version.""" + # EXPLODED_BUDGETS/RUNS use version keys without "v" prefix (e.g., "2.2" not "v2.2") + version_key = version.lstrip("v") + exploded_runs = EXPLODED_RUNS.get(version_key, []) + exploded_budgets = EXPLODED_BUDGETS.get(version_key, []) + + run_mask = df["run_name"].isin(exploded_runs) + budget_mask = df["flops_budget"].isin(exploded_budgets) + + n_filtered = run_mask.sum() + budget_mask.sum() + if n_filtered > 0: + logger.warning(f"Filtered {n_filtered} exploded runs/budgets for {version}") + + return df[~run_mask & ~budget_mask] + + +def fit_log_quadratic(x_vals: np.ndarray, loss_vals: np.ndarray) -> tuple[float, np.ndarray]: + """Fits L = a*(ln x)² + b*(ln x) + c and returns the optimal x.""" + log_x = np.log(x_vals) + coeffs = np.polyfit(log_x, loss_vals, 2) # [a, b, c] + a, b, _ = coeffs + ln_x_opt = -b / (2 * a) + return np.exp(ln_x_opt), coeffs + + +def fit_parabolic(df: pd.DataFrame) -> pd.DataFrame: + """Fit each curve_id separately with np.polyfit.""" + rows = [] + for curve_id, group in df.groupby("curve_id"): + if len(group) < 3: + continue + opt_N, coeffs_N = fit_log_quadratic(group["params"].values, group["eval_loss"].values) + opt_D, coeffs_D = fit_log_quadratic(group["tokens"].values, group["eval_loss"].values) + rows.append( + { + "curve_id": curve_id, + "version": group["version"].iloc[0], + "budget": group["flops_budget"].iloc[0], + "opt_N": opt_N, + "opt_D": opt_D, + "coeffs_N": coeffs_N, + "coeffs_D": coeffs_D, + } + ) + return pd.DataFrame(rows) + + +def _compute_huber_scale(L: np.ndarray, scale_factor: float = 2.0) -> float: + """Compute a data-driven f_scale for Huber loss. + + Uses the median absolute deviation (MAD) of the loss values, which is + robust to outliers. The scale_factor (default 2.0) determines how many + MADs a residual can be before being treated as an outlier. + + Args: + L: Array of loss values. + scale_factor: Multiplier for MAD (default 2.0 means residuals > 2*MAD are outliers). + + Returns: + f_scale value for scipy.optimize.least_squares with loss='huber'. + """ + # MAD = median(|L - median(L)|) + # For normal data, std ≈ 1.4826 * MAD + mad = np.median(np.abs(L - np.median(L))) + return scale_factor * mad + + +def _grid_search_alpha_beta( + N: np.ndarray, D: np.ndarray, L: np.ndarray, alpha_range: np.ndarray, beta_range: np.ndarray +) -> tuple[float, float, float, float, float]: + """Grid search over (α, β) with linear least squares for (E, A, B). + + For fixed α, β, the model L = E + A/N^α + B/D^β is linear in (E, A, B), + so we can solve it efficiently with linear least squares. + + Returns best (E, A, B, α, β). + """ + from scipy.optimize import nnls + + best_cost = np.inf + best_params = (0.0, 0.0, 0.0, 0.0, 0.0) + + for alpha in alpha_range: + for beta in beta_range: + # Design matrix: L = E*1 + A*(1/N^α) + B*(1/D^β) + X = np.column_stack([np.ones_like(N), 1.0 / N**alpha, 1.0 / D**beta]) + + # Non-negative least squares to enforce E, A, B >= 0 + coeffs, cost = nnls(X, L) + E, A, B = coeffs + + if cost < best_cost: + best_cost = cost + best_params = (E, A, B, alpha, beta) + + return best_params + + +def _huber_cost(residuals: np.ndarray, f_scale: float) -> float: + """Compute Huber cost matching scipy.optimize.least_squares convention.""" + z = (residuals / f_scale) ** 2 + rho = np.where(z <= 1, z, 2 * np.sqrt(z) - 1) + return 0.5 * np.sum(rho) * f_scale**2 + + +def _check_optimization_progress( + x0: list[float], + result, + residuals_fn, + param_names: list[str], + context: str, + f_scale: float, + min_rel_improvement: float = 0.0, +) -> None: + """Log parameter changes and verify the optimization improved the objective. + + Args: + x0: Initial parameter values. + result: Result from scipy.optimize.least_squares. + residuals_fn: The residuals function used in optimization. + param_names: Names for each parameter (for logging). + context: Description string for error messages (e.g., version name). + f_scale: Scale parameter for Huber loss. + min_rel_improvement: Minimum relative reduction in cost (e.g., 1e-4 = 0.01%). + + Raises: + RuntimeError: If cost did not decrease by at least min_rel_improvement. + """ + initial_cost = _huber_cost(residuals_fn(x0), f_scale) + final_cost = _huber_cost(result.fun, f_scale) + rel_improvement = (initial_cost - final_cost) / (initial_cost + 1e-12) + + # Log parameter changes + changes = [f"{name}: {x0[i]:.4g} → {result.x[i]:.4g}" for i, name in enumerate(param_names)] + logger.info(f" Parameters: {', '.join(changes)}") + logger.info(f" Huber cost: {initial_cost:.6g} → {final_cost:.6g} ({rel_improvement:.2%} reduction)") + + if rel_improvement < min_rel_improvement: + raise RuntimeError( + f"Optimization did not improve objective for {context}: " + f"cost {initial_cost:.6g} → {final_cost:.6g} ({rel_improvement:.2%} < {min_rel_improvement:.2%} required)" + ) + + +def _check_bound_violations( + params: dict[str, float], + bounds: dict[str, tuple[float, float]], + context: str, + rel_tol: float = 0.01, +) -> None: + """Check if any parameters are hitting their bounds. + + Args: + params: Dictionary of parameter name to value. + bounds: Dictionary of parameter name to (lower, upper) bounds. + context: Description string for error messages. + rel_tol: Relative tolerance for detecting bound proximity. + + Raises: + RuntimeError: If any parameter is near its bounds. + """ + violations = [] + for name, value in params.items(): + if name not in bounds: + continue + lo, hi = bounds[name] + if np.isfinite(lo) and value < lo + rel_tol * max(abs(lo), 1): + violations.append(f"{name}={value:.4g} near lower bound {lo}") + if np.isfinite(hi) and value > hi - rel_tol * max(abs(hi), 1): + violations.append(f"{name}={value:.4g} near upper bound {hi}") + + if violations: + raise RuntimeError( + f"Parameters hitting bounds for {context}: {', '.join(violations)}. " + "Consider relaxing bounds or investigating data quality." + ) + + +def fit_parametric(df: pd.DataFrame) -> pd.DataFrame: + """Fit Chinchilla Approach 3: parametric loss model per experiment. + + Model: L(N, D) = E + A/N^α + B/D^β + + Uses a two-stage approach: + 1. Grid search over (α, β) with linear least squares for (E, A, B) + 2. Refine with nonlinear optimization starting from grid search result + + Then computes per-curve optimal N* and D* using the 6ND compute assumption. + + Returns DataFrame with per-curve results compatible with parabolic output format. + """ + rows = [] + + # Grid for α, β search + alpha_range = np.linspace(0.05, 1.0, 512) + beta_range = np.linspace(0.05, 1.0, 512) + + for version, version_group in df.groupby("version"): + N = version_group["params"].values + D = version_group["tokens"].values + L = version_group["eval_loss"].values + + # Log data diagnostics + mad = np.median(np.abs(L - np.median(L))) + logger.info( + f"Fitting {version}: n={len(L)}, loss=[{L.min():.4f}, {L.max():.4f}], " f"std={L.std():.4f}, MAD={mad:.4f}" + ) + + # Stage 1: Grid search for good initial values + E_init, A_init, B_init, alpha_init, beta_init = _grid_search_alpha_beta(N, D, L, alpha_range, beta_range) + logger.info( + f" Grid search: E={E_init:.4f}, A={A_init:.2f}, B={B_init:.2f}, " f"α={alpha_init:.4f}, β={beta_init:.4f}" + ) + + # Check grid search didn't hit boundaries + _check_bound_violations( + {"α_grid": alpha_init, "β_grid": beta_init}, + {"α_grid": (alpha_range[0], alpha_range[-1]), "β_grid": (beta_range[0], beta_range[-1])}, + f"{version} grid search", + ) + + # Stage 2: Refine with nonlinear optimization + def residuals(params, N=N, D=D, L=L): + E, log_A, log_B, alpha, beta = params + A = np.exp(log_A) + B = np.exp(log_B) + L_pred = E + A / N**alpha + B / D**beta + return L_pred - L + + # Use grid search result as initial guess (convert A, B to log space) + x0 = [E_init, np.log(max(A_init, 1e-10)), np.log(max(B_init, 1e-10)), alpha_init, beta_init] + + # Compute data-driven f_scale for Huber loss + f_scale = _compute_huber_scale(L) + logger.info(f" Huber f_scale={f_scale:.6f} (2×MAD)") + + # Parameter bounds: [E, log_A, log_B, alpha, beta] + E_bounds = (0.0, np.inf) + alpha_bounds = (0.01, 2.0) + beta_bounds = (0.01, 2.0) + lower_bounds = [E_bounds[0], -np.inf, -np.inf, alpha_bounds[0], beta_bounds[0]] + upper_bounds = [E_bounds[1], np.inf, np.inf, alpha_bounds[1], beta_bounds[1]] + + # Fit with Huber loss (robust to outliers) + result = least_squares( + residuals, + x0, + loss="huber", + f_scale=f_scale, + bounds=(lower_bounds, upper_bounds), + ) + + # Check convergence diagnostics + if not result.success: + raise RuntimeError( + f"Optimization did not converge for {version}: " f"status={result.status}, message='{result.message}'" + ) + logger.info( + f" Optimization converged: status={result.status}, " + f"nfev={result.nfev}, optimality={result.optimality:.2e}" + ) + + # Verify optimization actually improved and log changes + _check_optimization_progress(x0, result, residuals, ["E", "log_A", "log_B", "α", "β"], version, f_scale) + + E, log_A, log_B, alpha, beta = result.x + A = np.exp(log_A) + B = np.exp(log_B) + + # Check if any bounded parameters are hitting their bounds + _check_bound_violations( + {"E": E, "α": alpha, "β": beta}, + {"E": E_bounds, "α": alpha_bounds, "β": beta_bounds}, + version, + ) + + # Compute fit residuals for diagnostics + L_pred = E + A / N**alpha + B / D**beta + residuals_final = L - L_pred + rmse = np.sqrt(np.mean(residuals_final**2)) + mae = np.mean(np.abs(residuals_final)) + + logger.info(f" Refined fit: E={E:.4f}, A={A:.2f}, B={B:.2f}, " f"α={alpha:.4f}, β={beta:.4f}") + logger.info(f" Fit quality: RMSE={rmse:.6f}, MAE={mae:.6f}, cost={result.cost:.6f}") + + # Compute scaling exponents from C = k*N*D assumption + # a = β/(α+β), b = α/(α+β) + a = beta / (alpha + beta) + b = alpha / (alpha + beta) + logger.info(f" Scaling exponents: a={a:.4f} (N* ∝ C^a), b={b:.4f} (D* ∝ C^b)") + + # Compute per-curve optimal N* and D* + # N* = (αA/βB)^(1/(α+β)) * (C/k)^(β/(α+β)) + # D* = (βB/αA)^(1/(α+β)) * (C/k)^(α/(α+β)) + coeff_N = (alpha * A / (beta * B)) ** (1 / (alpha + beta)) + coeff_D = (beta * B / (alpha * A)) ** (1 / (alpha + beta)) + + # Log per-curve k values + k_values = [] + for curve_id, curve_group in version_group.groupby("curve_id"): + if len(curve_group) < 3: + continue + + C = curve_group["flops_budget"].iloc[0] + N_curve = curve_group["params"].values + D_curve = curve_group["tokens"].values + + # Compute k via log-space regression: log(C) = log(k) + log(N) + log(D) + # => k = C / geometric_mean(N * D) + log_ND = np.log(N_curve * D_curve) + k = C / np.exp(log_ND.mean()) + k_values.append(k) + + opt_N = coeff_N * (C / k) ** a + opt_D = coeff_D * (C / k) ** b + + # Store parametric params as coeffs for plotting + # Format: (E, A, B, alpha, beta, C, k) + parametric_coeffs = (E, A, B, alpha, beta, C, k) + + rows.append( + { + "curve_id": curve_id, + "version": version, + "budget": C, + "opt_N": opt_N, + "opt_D": opt_D, + "coeffs_N": parametric_coeffs, + "coeffs_D": parametric_coeffs, + } + ) + + # Log k statistics for this version + if k_values: + k_arr = np.array(k_values) + logger.info( + f" FLOPs coefficient k: mean={k_arr.mean():.2f}, " + f"range=[{k_arr.min():.2f}, {k_arr.max():.2f}] (C = k·N·D)" + ) + + return pd.DataFrame(rows) + + +def fit_curves(df: pd.DataFrame) -> pd.DataFrame: + """Route to appropriate fitting function based on FIT_STRATEGY.""" + if FIT_STRATEGY == "parabolic": + return fit_parabolic(df) + elif FIT_STRATEGY == "parametric": + return fit_parametric(df) + else: + raise ValueError(f"Unknown FIT_STRATEGY: {FIT_STRATEGY}") + + +def iter_filtered_data( + data: dict[str, pd.DataFrame], +) -> Iterator[tuple[str, str, int, pd.DataFrame]]: + """Iterate over compute ranges, yielding filtered dataframes. + + Applies standard filtering (exploded runs, finished runs) to each version's data. + + Yields: + Tuples of (version, display_name, steps, filtered_df). + """ + for version, name, steps in COMPUTE_RANGES: + if version not in data: + continue + df = data[version].copy() + df = filter_exploded_runs_for_version(df, version) + df = filter_to_finished_runs(df) + yield version, name, steps, df + + +def print_budget_summary(data: dict[str, pd.DataFrame]) -> None: + """Print a sorted list of compute range, version, and flop budgets.""" + rows: list[tuple[str, str, float]] = [] + + for version, name, _steps, df in iter_filtered_data(data): + for budget in sorted(df["flops_budget"].unique()): + rows.append((name, version, budget)) + + # Sort by compute range name (which includes step count), then by budget + rows.sort(key=lambda x: (x[0], x[2])) + + logger.info("\n" + "=" * 60) + logger.info("Compute Range / Version / FLOPs Budget Summary") + logger.info("=" * 60) + logger.info(f"{'Compute Range':<20} {'Version':<10} {'FLOPs Budget':<15}") + logger.info("-" * 60) + for name, version, budget in rows: + logger.info(f"{name:<20} {version:<10} {budget:.2e}") + logger.info("=" * 60 + "\n") + + +def apply_fit_filters(df: pd.DataFrame) -> pd.DataFrame: + """Apply PARAMS_RANGE, TOKENS_RANGE, and BATCH_SIZE_RANGE filters for fitting.""" + filtered = df.copy() + if PARAMS_RANGE[0] is not None: + filtered = filtered[filtered["params"] >= PARAMS_RANGE[0]] + if PARAMS_RANGE[1] is not None: + filtered = filtered[filtered["params"] <= PARAMS_RANGE[1]] + if TOKENS_RANGE[0] is not None: + filtered = filtered[filtered["tokens"] >= TOKENS_RANGE[0]] + if TOKENS_RANGE[1] is not None: + filtered = filtered[filtered["tokens"] <= TOKENS_RANGE[1]] + if BATCH_SIZE_RANGE[0] is not None: + filtered = filtered[filtered["batch_size"] >= BATCH_SIZE_RANGE[0]] + if BATCH_SIZE_RANGE[1] is not None: + filtered = filtered[filtered["batch_size"] <= BATCH_SIZE_RANGE[1]] + return filtered + + +def prepare_combined_data(data: dict[str, pd.DataFrame], architecture: str = DEFAULT_ARCH) -> pd.DataFrame: + """Combine all versions into one DataFrame, filter, add curve_id and budget_index.""" + dfs = [] + for _version, name, _steps, df in iter_filtered_data(data): + df = df[(df["architecture"] == architecture) & (df["epochs"] == 1)].dropna( + subset=["eval_loss", "tokens", "params", "flops_budget"] + ) + if df.empty: + logger.warning(f"No valid data for {name}") + continue + # Compute budget_index within each version (0 = smallest budget, 1 = next, etc.) + sorted_budgets = sorted(df["flops_budget"].unique()) + budget_to_index = {b: i for i, b in enumerate(sorted_budgets)} + df = df.copy() + df["budget_index"] = df["flops_budget"].map(budget_to_index) + dfs.append(df) + if not dfs: + return pd.DataFrame() + combined = pd.concat(dfs, ignore_index=True) + combined["curve_id"] = combined["version"] + "_" + combined["flops_budget"].astype(str) + return combined + + +def prepare_data( + data: dict[str, pd.DataFrame], architecture: str = DEFAULT_ARCH +) -> dict[str, tuple[pd.DataFrame, pd.DataFrame]]: + """Prepare data for plotting. Returns dict[version -> (data_df, analysis_df)].""" + combined = prepare_combined_data(data, architecture) + if combined.empty: + return {} + + filtered = apply_fit_filters(combined) + analysis = fit_curves(filtered) + + # Split back per-version, adding group_data for plotting compatibility + results: dict[str, tuple[pd.DataFrame, pd.DataFrame]] = {} + for version in combined["version"].unique(): + version_data = combined[combined["version"] == version] + version_analysis = analysis[analysis["version"] == version].copy() + version_analysis["group_data"] = version_analysis["curve_id"].apply( + lambda cid: combined[combined["curve_id"] == cid] + ) + results[version] = (version_data, version_analysis) + return results + + +def plot_loss_vs_variable( + ax, + range_data: dict[str, tuple[pd.DataFrame, pd.DataFrame]], + variable: str, +): + """Plot Loss vs a variable (params or tokens) with all compute ranges overlaid. + + Args: + ax: Matplotlib axes to plot on. + range_data: Dictionary mapping version to (dataframe, analysis) tuples. + variable: Either "params" or "tokens". + """ + if variable == "params": + col_name = "params" + coeffs_key = "coeffs_N" + opt_key = "opt_N" + xlabel = "Parameters (N)" + title = "Loss vs Parameters" + fit_range = PARAMS_RANGE + else: + col_name = "tokens" + coeffs_key = "coeffs_D" + opt_key = "opt_D" + xlabel = "Tokens (D)" + title = "Loss vs Tokens" + fit_range = TOKENS_RANGE + + for version, name, _ in COMPUTE_RANGES: + if version not in range_data: + continue + + df, analysis = range_data[version] + color = RANGE_COLORS[version] + + # Plot all data points for this range (more transparent) + ax.scatter(df[col_name], df["eval_loss"], color=color, alpha=0.25, s=15, label=name) + + # Plot fits and optimum points for each budget + for _, row in analysis.iterrows(): + group = row["group_data"] + val_min, val_max = group[col_name].min(), group[col_name].max() + + # Draw fit curve within actual data range + x_range = np.logspace(np.log10(val_min), np.log10(val_max), 100) + + if FIT_STRATEGY == "parabolic": + # Parabolic: L = a*(ln x)² + b*(ln x) + c + L_pred = np.polyval(row[coeffs_key], np.log(x_range)) + elif FIT_STRATEGY == "parametric": + # Parametric: L(N, D) = E + A/N^α + B/D^β with constraint D = C/(kN) + # k is computed per-curve from actual data + E, A, B, alpha, beta, C, k = row[coeffs_key] + if variable == "params": + # L(N) with D = C/(kN) + L_pred = E + A / x_range**alpha + B * (k * x_range / C) ** beta + else: + # L(D) with N = C/(kD) + L_pred = E + A * (k * x_range / C) ** alpha + B / x_range**beta + else: + raise ValueError(f"Unknown FIT_STRATEGY: {FIT_STRATEGY}") + + ax.plot(x_range, L_pred, color=color, linestyle="--", alpha=0.5, linewidth=1) + + # Plot optimum point + if pd.notna(row[opt_key]): + if FIT_STRATEGY == "parabolic": + L_min = np.polyval(row[coeffs_key], np.log(row[opt_key])) + elif FIT_STRATEGY == "parametric": + E, A, B, alpha, beta, C, k = row[coeffs_key] + if variable == "params": + L_min = E + A / row[opt_key] ** alpha + B * (k * row[opt_key] / C) ** beta + else: + L_min = E + A * (k * row[opt_key] / C) ** alpha + B / row[opt_key] ** beta + ax.scatter([row[opt_key]], [L_min], color=color, marker="s", s=60, edgecolors="black", zorder=10) + + ax.set_xscale("log") + ax.set_xlabel(xlabel) + ax.set_ylabel("Validation Loss") + ax.set_title(title) + ax.grid(True, which="both", ls="-", alpha=0.2) + ax.legend(fontsize=8, loc="upper right", framealpha=0.5) + + # Draw vertical lines at filter boundaries + if fit_range[0] is not None: + ax.axvline(fit_range[0], color="gray", linestyle=":", linewidth=1.5, alpha=0.7) + if fit_range[1] is not None: + ax.axvline(fit_range[1], color="gray", linestyle=":", linewidth=1.5, alpha=0.7) + + +def plot_optimal_params_vs_flops(ax, range_data: dict[str, tuple[pd.DataFrame, pd.DataFrame]]): + """Plot Optimal Parameters vs Compute with scaling laws per range.""" + for version, name, _ in COMPUTE_RANGES: + if version not in range_data: + continue + + _, analysis = range_data[version] + color = RANGE_COLORS[version] + + valid_data = analysis.dropna(subset=["opt_N", "opt_D"]) + if len(valid_data) < 2: + continue + + budgets = valid_data["budget"].values + opt_N = valid_data["opt_N"].values + + # Fit and plot scaling law + m_N, _c_N, B_smooth, N_smooth = fit_scaling_law(budgets, opt_N) + + ax.scatter(budgets, opt_N, color=color, marker="s", s=60, edgecolors="black", zorder=5) + ax.plot( + B_smooth, N_smooth, color=color, linestyle="--", alpha=0.7, label=f"{name}: $N^* \\propto C^{{{m_N:.2f}}}$" + ) + + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("Compute Budget (FLOPs)") + ax.set_ylabel("Optimal Parameters (N*)") + ax.set_title("Optimal Parameters vs Compute") + ax.grid(True, which="both", ls="-", alpha=0.2) + ax.legend(fontsize=7, loc="upper left", framealpha=0.5) + + +def plot_optimal_tokens_vs_flops(ax, range_data: dict[str, tuple[pd.DataFrame, pd.DataFrame]]): + """Plot Optimal Tokens vs Compute with scaling laws per range.""" + for version, name, _ in COMPUTE_RANGES: + if version not in range_data: + continue + + _, analysis = range_data[version] + color = RANGE_COLORS[version] + + valid_data = analysis.dropna(subset=["opt_N", "opt_D"]) + if len(valid_data) < 2: + continue + + budgets = valid_data["budget"].values + opt_D = valid_data["opt_D"].values + + # Fit and plot scaling law + m_D, _c_D, B_smooth, D_smooth = fit_scaling_law(budgets, opt_D) + + ax.scatter(budgets, opt_D, color=color, marker="s", s=60, edgecolors="black", zorder=5) + ax.plot( + B_smooth, D_smooth, color=color, linestyle="--", alpha=0.7, label=f"{name}: $D^* \\propto C^{{{m_D:.2f}}}$" + ) + + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("Compute Budget (FLOPs)") + ax.set_ylabel("Optimal Tokens (D*)") + ax.set_title("Optimal Tokens vs Compute") + ax.grid(True, which="both", ls="-", alpha=0.2) + ax.legend(fontsize=7, loc="upper left", framealpha=0.5) + + +def create_combined_figure(data: dict[str, pd.DataFrame], architecture: str = DEFAULT_ARCH): + """Create a 4-facet figure showing isoflop analysis grouped by compute range. + + Layout (2x2): + - Top-left: Loss vs Parameters + - Top-right: Loss vs Tokens + - Bottom-left: Optimal Params vs FLOPs + - Bottom-right: Optimal Tokens vs FLOPs + """ + range_data = prepare_data(data, architecture) + + if not range_data: + raise ValueError("No valid data found for any compute range") + + logger.info(f"Prepared data for {len(range_data)} compute ranges") + + # Create 2x2 figure + fig, axes = plt.subplots(2, 2, figsize=(12, 10)) + + # Top row: Loss curves + plot_loss_vs_variable(axes[0, 0], range_data, "params") + plot_loss_vs_variable(axes[0, 1], range_data, "tokens") + + # Bottom row: Scaling laws + plot_optimal_params_vs_flops(axes[1, 0], range_data) + plot_optimal_tokens_vs_flops(axes[1, 1], range_data) + + plt.suptitle("PlantCAD Isoflop Analysis: Bias by Compute Range (Step Count)", y=0.98, fontsize=14) + plt.tight_layout(rect=[0, 0, 1, 0.96]) + + return fig + + +def save_combined_data(data: dict[str, pd.DataFrame], output_path: Path) -> None: + """Save all combined data to CSV.""" + all_dfs = [df for _version, _name, _steps, df in iter_filtered_data(data)] + + if all_dfs: + combined = pd.concat(all_dfs, ignore_index=True) + combined.to_csv(output_path, index=False) + logger.info(f"Saved {len(combined)} rows to {output_path}") + + +def main(): + """Main entry point.""" + logger.info("Loading CSV files...") + data = load_all_csvs() + + if not data: + raise ValueError("No CSV files found") + + # Print sorted budget summary + print_budget_summary(data) + + logger.info(f"Creating combined figure from {len(data)} compute ranges...") + fig = create_combined_figure(data) + + # Create output directory + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + # Save figure as PNG and PDF + png_path = OUTPUT_DIR / f"{OUTPUT_BASE}.png" + pdf_path = OUTPUT_DIR / f"{OUTPUT_BASE}.pdf" + csv_path = OUTPUT_DIR / f"{OUTPUT_BASE}.csv" + + fig.savefig(png_path, dpi=EXPORT_DPI, bbox_inches="tight") + logger.info(f"Saved figure to {png_path}") + + fig.savefig(pdf_path, dpi=EXPORT_DPI, bbox_inches="tight") + logger.info(f"Saved figure to {pdf_path}") + + plt.close(fig) + + # Save combined data as CSV + save_combined_data(data, csv_path) + + +if __name__ == "__main__": + main() diff --git a/experiments/plantcad/exp2101_plantcad_multi_isoflop_sweep.py b/experiments/plantcad/exp2101_plantcad_multi_isoflop_sweep.py new file mode 100644 index 0000000000..df6a77a594 --- /dev/null +++ b/experiments/plantcad/exp2101_plantcad_multi_isoflop_sweep.py @@ -0,0 +1,701 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate ISOFlop sweep steps for varying model sizes, architectures and epochs on a target DNA dataset.""" + +import os +import math +import logging +import dataclasses +import numpy as np +import pandas as pd +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, replace + +from levanter.data.text import TextLmDatasetFormat +from levanter.models.llama import LlamaConfig +from levanter.models.qwen import Qwen3Config +from levanter.optim.cautious import CautiousConfig +from levanter.optim.config import OptimizerConfig +from levanter.utils.flop_utils import lm_flops_per_token +from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig + +from experiments.defaults import default_train, _prepare_data_config +from experiments.evals.task_configs import EvalTaskConfig +from experiments.llama import compute_num_parameters +from experiments.simple_train_config import SimpleTrainConfig +from marin.execution.executor import ExecutorStep, InputName, executor_main, this_output_path, versioned +from marin.processing.tokenize.tokenize import HfTokenizeConfig, tokenize +from fray.cluster import ResourceConfig + +logger = logging.getLogger("ray") + +# TPU v5p hardware constants for memory estimation +# Constants for TPU v5p +HBM_PER_CHIP_GIB = 95 +CORES_PER_CHIP = 2 +V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512] # TPU slices + +ModelConfig = LlamaConfig | Qwen3Config + + +def simulated_epoch_train( + name: str, + tokenized: InputName | ExecutorStep, + model_config: ModelConfig, + train_config: "SimpleTrainConfig", + train_tokens: int, + dataset_tokens: int, + epoch_count: int = 1, + tags: Sequence[str] = (), + use_default_validation: bool = False, + eval_harness_tasks: Sequence[EvalTaskConfig] = (), +) -> ExecutorStep: + """Train with simulated epoching. When epoch_count=1, uses full dataset.""" + if not isinstance(epoch_count, int) or epoch_count < 1: + raise ValueError(f"epoch_count must be int >= 1, got {epoch_count}") + + pretraining_data = _prepare_data_config(tokenized, use_default_validation) + + if epoch_count == 1: + return default_train( + name, + tokenized=pretraining_data, + model_config=model_config, + train_config=train_config, + tags=tags, + use_default_validation=use_default_validation, + eval_harness_tasks=eval_harness_tasks, + ) + + # To use simulated epoching in Levanter, we need to first address the fact that + # we are already limiting training to less than 1 true epoch in each run. + # + # The Levanter formula for this feature takes two arguments, experiment_budget and target_budget, + # and then uses this formula to determine how to slice each epoch: + # + # simulated_data_ratio = experiment_budget / target_budget + # simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio) + # sliced_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset) + # + # See: https://github.com/marin-community/marin/blob/eb4acbdd185a34202da16052c46c74eb570e69a5/lib/levanter/src/levanter/data/text.py#L1273-L1280 + # + # This means that `simulated_data_ratio` must become equal to `train_tokens / dataset_tokens / epoch_count` + # in order for the simulated epochs to work on top of a partial epoch. + # We accomplish this here by setting: + # - experiment_budget = train_tokens + # - target_budget = dataset_tokens * epoch_count + experiment_budget, target_budget = train_tokens, dataset_tokens * epoch_count + simulated_pretraining_data = dataclasses.replace( + pretraining_data, target_budget=target_budget, experiment_budget=experiment_budget + ) + + return default_train( + name, + tokenized=simulated_pretraining_data, + model_config=model_config, + train_config=train_config, + tags=tags, + use_default_validation=use_default_validation, + eval_harness_tasks=eval_harness_tasks, + ) + + +def format_num(n: int | float) -> str: + """Format numbers in T/B/M/K notation (e.g., 1.5T, 100B, 5.2M, 1.0K).""" + if n >= 1_000_000_000_000: + return f"{n / 1_000_000_000_000:.1f}T" + elif n >= 1_000_000_000: + return f"{n / 1_000_000_000:.1f}B" + elif n >= 10_000_000: + return f"{int(n / 1_000_000)}M" + elif n >= 1_000_000: + return f"{n / 1_000_000:.1f}M" + elif n >= 1_000: + return f"{n / 1_000:.1f}K" + return str(int(n)) + + +@dataclass(frozen=True) +class IsoFlopDataConfig: + dataset_name: str = versioned("plantcad/plantcad2-c4096") + dataset_revision: str | None = versioned("e551075a775a2fd2c04f0f97741dc68b09bde653") + seq_len: int = 4096 + total_token_count: int = 10_807_934_976 # 2,638,656 examples * 4,096 in train split + + +@dataclass(frozen=True) +class IsoFlopTokenizeConfig(HfTokenizeConfig): + tokenizer: str = versioned("kuleshov-group/PlantCAD2-Small-l24-d0768") + format: TextLmDatasetFormat = dataclasses.field(default_factory=lambda: TextLmDatasetFormat(text_key="seq")) + vocab_size: int = 7 + + +@dataclass(frozen=True) +class IsoFlopSweepParams: + """Configuration for a specific compute range in the ISOFlop sweep.""" + + experiment_name: str + compute_range_name: str + budgets: list[float] + steps_per_run: int + hidden_step_size: int = 112 # 112 + hidden_head_ratio: int = 128 # 128 + + +# Predefined compute range configurations +ISOFLOP_SWEEPS = { + # minimal-compute (1x): 5e15 to 6.25e16, steps=2_048 (v2.3) + # "minimal": IsoFlopSweepParams( + # experiment_name="plantcad_isoflop_v2.3", + # compute_range_name="minimal", + # budgets=list(np.logspace(np.log10(5e15), np.log10(6.25e16), 5)), + # steps_per_run=2_048, + # ), + # # very-low-compute (2x): 1e16 to 1.25e17, steps=4_096 (v2.6) + # "very-low": IsoFlopSweepParams( + # experiment_name="plantcad_isoflop_v2.6", + # compute_range_name="very-low", + # budgets=list(np.logspace(np.log10(1e16), np.log10(1.25e17), 5)), + # steps_per_run=4_096, + # ), + # # low-compute (4x): 2e16 to 2.5e17, steps=8_192 (v2.4) + # "low": IsoFlopSweepParams( + # experiment_name="plantcad_isoflop_v2.4", + # compute_range_name="low", + # budgets=list(np.logspace(np.log10(2e16), np.log10(2.5e17), 5)), + # steps_per_run=8_192, + # ), + # mid-compute (8x): 4e16 to 5e17, steps=16_384 (v2.5) + # "mid": IsoFlopSweepParams( + # experiment_name="plantcad_isoflop_v2.5", + # compute_range_name="mid", + # budgets=list(np.logspace(np.log10(4e16), np.log10(5e17), 5)), + # steps_per_run=16_384, + # ), + # high-compute (16x): 8e16 to 1e18, steps=32_768 (v2.2) + "high": IsoFlopSweepParams( + experiment_name="plantcad_isoflop_v2.2", + compute_range_name="high", + budgets=list(np.logspace(np.log10(8e16), np.log10(1e18), 5)), + steps_per_run=32_768, + ), +} + + +@dataclass(frozen=True) +class IsoFlopSweepConfig: + tokenized_dataset: InputName | str + vocab_size: int + seq_len: int + total_token_count: int + experiment_name: str + compute_range_name: str + budgets: list[float] + steps_per_run: int + hidden_step_size: int + hidden_head_ratio: int + + epochs: list[int] = dataclasses.field(default_factory=lambda: [1]) + min_hidden_pow: int = 8 + max_hidden_pow: int = 10 + mlp_ratio: int = 4 + base_hidden_layer_ratio: int = 64 + lr_max: float | None = 0.03 + flop_tolerance: float = 0.01 + architectures: list[str] = dataclasses.field(default_factory=lambda: ["qwen"]) + # TODO: adjust eval example count to account for num tpus in the even that v5p-8 is not used + per_device_eval_parallelism: int = 512 + max_eval_batches: int = 64 + num_evals: int = 3 + + lr_constant: float = 0.33 + base_optimizer_config: OptimizerConfig = dataclasses.field( + default_factory=lambda: CautiousConfig( + learning_rate=1.0, + weight_decay=0.1, + min_lr_ratio=0.0, + warmup=0.1, + beta1=0.95, + beta2=0.98, + epsilon=1e-15, + max_grad_norm=1, + adamc_weight_decay=True, + lr_schedule="linear", + decay=0.2, + ), + ) + base_train_config: SimpleTrainConfig = dataclasses.field( + default_factory=lambda: SimpleTrainConfig( + resources=ResourceConfig.with_tpu("v5p-8"), + train_batch_size=1, + num_train_steps=50_000, + learning_rate=1.0, + weight_decay=0.1, + min_lr_ratio=0.0, + lr_schedule="linear", + decay=0.2, + ) + ) + + +def estimate_bytes( + param_count: int, + hidden_dim: int, + num_layers: int, + batch: int, + seq_len: int, + vocab: int, + optim_mult: int = 3, + dtype_size: int = 4, + fudge_factor: float = 2, +) -> int: + """Estimate memory usage (in bytes) for one training step.""" + param_bytes = param_count * optim_mult * dtype_size + + act_bytes = (batch * seq_len) * ((hidden_dim * num_layers) + vocab * fudge_factor) + + total_bytes = param_bytes + act_bytes + return int(total_bytes) * fudge_factor + + +def pick_v5p_type( + config: ModelConfig, + hidden: int, + layers: int, + batch: int, + seq_len: int, + vocab: int, +) -> str: + """Select the smallest TPU v5p slice that fits the model in float32.""" + param_count = compute_num_parameters(config, vocab) + need_bytes = estimate_bytes(param_count, hidden, layers, batch, seq_len, vocab) + chip_bytes = HBM_PER_CHIP_GIB * 1024**3 + chips = math.ceil(need_bytes / chip_bytes) + cores_req = chips * CORES_PER_CHIP + + valid = [c for c in V5P_CORE_OPTIONS if c >= cores_req] + if not valid: + raise ValueError(f"Model too large for available v5p slices (need {cores_req} cores).") + + return f"v5p-{min(valid)}" + + +def round_to_power_of_two(x: float) -> int: + """Round ``x`` to the nearest power of two.""" + if x <= 1: + return 1 + return 2 ** math.ceil(math.log2(x)) + + +def compute_total_flops( + batch: int, + num_layers: int, + hidden: int, + intermediate: int, + num_kv_heads: int, + num_heads: int, + steps: int, + seq_len: int, + vocab_size: int, +) -> float: + """Compute total training FLOPs using Levanter utilities.""" + + flops_per_token = lm_flops_per_token( + hidden, + intermediate, + num_layers, + num_kv_heads, + num_heads, + seq_len, + vocab_size, + glu=True, + ) + return flops_per_token * batch * steps * seq_len + + +@dataclass +class IsoFlopRunConfig: + experiment_name: str + compute_range_name: str + steps_per_run: int + hidden_step_size: int + architecture: str + hidden_size: int + intermediate_dim: int + num_layers: int + n_heads: int + n_kv_heads: int + batch_size: int + batch_exact: float + batch_rounded: int + train_steps: int + lr: float + beta2: float + budget: float + steps_per_eval: int + train_tokens: int + dataset_tokens: int + num_params: int + epoch_count: int + model_config: ModelConfig + + +def generate_run_configs(cfg: IsoFlopSweepConfig, budget: float) -> Iterator[IsoFlopRunConfig]: + """Generate ISOFlop run configurations within the FLOP budget.""" + + dataset_tokens = cfg.total_token_count + + # Loop over architecture as the primary dimension of the search space + for architecture in cfg.architectures: + # Loop through hidden size on a grid, which will determine the model + # size and therefore token count for each run config + for hidden_size in range(2**cfg.min_hidden_pow, (2**cfg.max_hidden_pow) + 1, cfg.hidden_step_size): + hs_pow = math.log2(hidden_size) + intermediate_dim = hidden_size * cfg.mlp_ratio + num_layers = round(hidden_size / (cfg.base_hidden_layer_ratio + (hs_pow * 4) - cfg.min_hidden_pow)) + assert ( + hidden_size % cfg.hidden_head_ratio == 0 + ), f"hidden_size ({hidden_size}) must be evenly divisible by hidden_head_ratio ({cfg.hidden_head_ratio})" + n_heads = max(1, hidden_size // cfg.hidden_head_ratio) + n_kv_heads = n_heads + + # Calculate batch size to meet budget with fixed steps + batch_exact = budget / compute_total_flops( + batch=1, + num_layers=num_layers, + hidden=hidden_size, + intermediate=intermediate_dim, + num_kv_heads=n_kv_heads, + num_heads=n_heads, + steps=cfg.steps_per_run, + seq_len=cfg.seq_len, + vocab_size=cfg.vocab_size, + ) + + batch_rounded = round_to_power_of_two(batch_exact) + batch_size = batch_rounded + + # Scale LR with sqrt(batch) and hidden size + # Reference: https://arxiv.org/pdf/2203.03466 (Section 10 Related Works) + lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size + + # Halve batch size until LR is stable + if cfg.lr_max is not None: + while lr > cfg.lr_max: + old_batch = batch_size + batch_size //= 2 + lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size + logger.warning( + f"Halving batch size for ({architecture=}, {hidden_size=}): " + f"{old_batch} -> {batch_size} (lr={lr:.4f}, lr_max={cfg.lr_max})" + ) + + # Set beta2 based on https://arxiv.org/pdf/2507.07101 + b2 = 0.98 ** (batch_size / 128) + + if batch_size < 8: + logger.warning( + f"Skipping config for ({budget=:.1e}, {architecture=}, {hidden_size=}) " + f"with batch size {batch_size} (less than 8)" + ) + continue + + # Recompute exact steps based on adjusted batch size + steps_exact = budget / compute_total_flops( + batch=batch_size, + num_layers=num_layers, + hidden=hidden_size, + intermediate=intermediate_dim, + num_kv_heads=n_kv_heads, + num_heads=n_heads, + steps=1, + seq_len=cfg.seq_len, + vocab_size=cfg.vocab_size, + ) + train_steps = round(steps_exact) + + # Ensure actual flops still within range + achieved_flops = compute_total_flops( + batch=batch_size, + num_layers=num_layers, + hidden=hidden_size, + intermediate=intermediate_dim, + num_kv_heads=n_kv_heads, + num_heads=n_heads, + steps=train_steps, + seq_len=cfg.seq_len, + vocab_size=cfg.vocab_size, + ) + if abs(achieved_flops - budget) / budget > cfg.flop_tolerance: + logger.warning( + f"Skipping config for ({budget=:.1e}, {architecture=}, {hidden_size=}) " + f"with achieved flops {achieved_flops} (not within {cfg.flop_tolerance} of budget {budget})" + ) + continue + + train_tokens = train_steps * batch_size * cfg.seq_len + # Subtract 1 from num_evals to account for the first evaluation + num_evals = max(1, cfg.num_evals - 1) + steps_per_eval = max(1, train_steps // num_evals) + + if train_tokens > dataset_tokens: + logger.warning( + f"Skipping config for ({budget=:.1e}, {architecture=}, {hidden_size=}) " + f"with train tokens {train_tokens} (greater than dataset tokens {dataset_tokens})" + ) + continue + + if architecture == "llama": + model_cfg = LlamaConfig( + max_seq_len=cfg.seq_len, + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_heads=n_heads, + num_kv_heads=n_kv_heads, + num_layers=num_layers, + ) + elif architecture == "qwen": + model_cfg = Qwen3Config( + max_seq_len=cfg.seq_len, + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_heads=n_heads, + num_kv_heads=n_kv_heads, + num_layers=num_layers, + rope=Llama3RotaryEmbeddingsConfig(), + ) + else: + raise ValueError(f"Unknown architecture: {architecture}") + + num_params = compute_num_parameters(model_cfg, cfg.vocab_size) + + for epoch_count in cfg.epochs: + yield IsoFlopRunConfig( + experiment_name=cfg.experiment_name, + compute_range_name=cfg.compute_range_name, + steps_per_run=cfg.steps_per_run, + hidden_step_size=cfg.hidden_step_size, + architecture=architecture, + hidden_size=hidden_size, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + batch_size=batch_size, + batch_exact=batch_exact, + batch_rounded=batch_rounded, + train_steps=train_steps, + lr=lr, + beta2=b2, + budget=budget, + steps_per_eval=steps_per_eval, + train_tokens=train_tokens, + dataset_tokens=dataset_tokens, + num_params=num_params, + epoch_count=epoch_count, + model_config=model_cfg, + ) + + +def _log_isoflop_run_configs(all_configs: list[IsoFlopRunConfig]): + """Log summary of generated ISOFlop configurations.""" + if all_configs: + df = pd.DataFrame([dataclasses.asdict(c) for c in all_configs]) + + # Format large numbers for readability + if "num_params" in df.columns: + df["num_params"] = df["num_params"].apply(format_num) + if "train_tokens" in df.columns: + df["train_tokens"] = df["train_tokens"].apply(format_num) + + pd.set_option("display.max_rows", None) + pd.set_option("display.max_columns", None) + pd.set_option("display.width", 1000) + + logger.info("\n" + "=" * 80) + logger.info("Configuration Summary Dataframe") + logger.info("=" * 80) + logger.info("\n" + str(df.drop(columns=["model_config"]))) + + logger.info("\n" + "=" * 50) + logger.info("Configs per Budget") + logger.info("=" * 50) + logger.info("\n" + str(df.groupby("budget").size())) + + logger.info("\n" + "=" * 50) + logger.info("Configs per Architecture") + logger.info("=" * 50) + logger.info("\n" + str(df.groupby("architecture").size())) + + logger.info("=" * 50 + "\n") + else: + logger.warning("No configurations generated!") + + +def generate_isoflop_steps(config: IsoFlopSweepConfig) -> list[ExecutorStep]: + """Generate executor steps for an ISOFlop sweep.""" + + # Collect all run configs first + all_configs: list[IsoFlopRunConfig] = [] + + for budget in config.budgets: + for c in generate_run_configs(config, budget): + all_configs.append(c) + + _log_isoflop_run_configs(all_configs) + + # Generate executor steps from validated configs + steps: list[ExecutorStep] = [] + for c in all_configs: + # Use the pre-computed model config + model_cfg = c.model_config + + tpu_type = pick_v5p_type( + config=model_cfg, + hidden=c.hidden_size, + layers=c.num_layers, + batch=c.batch_size, + seq_len=config.seq_len, + vocab=config.vocab_size, + ) + optimizer_cfg = replace(config.base_optimizer_config, learning_rate=c.lr, beta2=c.beta2) + train_cfg = replace( + config.base_train_config, + train_batch_size=c.batch_size, + learning_rate=c.lr, + num_train_steps=c.train_steps, + steps_per_eval=c.steps_per_eval, + per_device_eval_parallelism=config.per_device_eval_parallelism, + max_eval_batches=config.max_eval_batches, + resources=ResourceConfig.with_tpu(tpu_type), + optimizer_config=optimizer_cfg, + ) + + param_count = c.num_params + step = simulated_epoch_train( + name="-".join( + [ + config.experiment_name, + f"A_{c.architecture}", + f"F{c.budget:.1e}", + f"P{format_num(param_count)}", + f"T{format_num(c.train_tokens)}", + f"E{c.epoch_count}", + ] + ), + tokenized=config.tokenized_dataset, + model_config=model_cfg, + train_config=train_cfg, + train_tokens=c.train_tokens, + dataset_tokens=c.dataset_tokens, + epoch_count=c.epoch_count, + eval_harness_tasks=[], + use_default_validation=False, + tags=( + f"architecture={c.architecture}", + f"flops_budget={c.budget:.1e}", + f"hidden_size={c.hidden_size}", + f"num_layers={c.num_layers}", + f"batch_size={c.batch_size}", + f"steps={c.train_steps}", + f"tpu={tpu_type}", + f"params={param_count}", + f"tokens={c.train_tokens}", + f"epochs={c.epoch_count}", + f"vocab_size={config.vocab_size}", + ), + ) + steps.append(step) + + return steps + + +def generate_isoflop_sweeps( + tokenized: ExecutorStep, + vocab_size: int, + seq_len: int, + total_token_count: int, +) -> list[ExecutorStep]: + """Generate executor steps for all ISOFlop sweeps.""" + all_steps: list[ExecutorStep] = [] + for sweep_params in ISOFLOP_SWEEPS.values(): + sweep_cfg = IsoFlopSweepConfig( + tokenized_dataset=tokenized, + vocab_size=vocab_size, + seq_len=seq_len, + total_token_count=total_token_count, + experiment_name=sweep_params.experiment_name, + compute_range_name=sweep_params.compute_range_name, + budgets=sweep_params.budgets, + steps_per_run=sweep_params.steps_per_run, + hidden_step_size=sweep_params.hidden_step_size, + hidden_head_ratio=sweep_params.hidden_head_ratio, + ) + steps = generate_isoflop_steps(sweep_cfg) + all_steps.extend(steps) + return all_steps + + +def tokenize_plantcad() -> tuple[ExecutorStep, IsoFlopDataConfig]: + """Tokenize the PlantCAD dataset directly from HuggingFace. + + Returns: + A tuple of (tokenized ExecutorStep, IsoFlopDataConfig with dataset metadata). + """ + data_config = IsoFlopDataConfig() + step = ExecutorStep( + name="tokenized/plantcad2", + fn=tokenize, + config=IsoFlopTokenizeConfig( + id=data_config.dataset_name, + revision=data_config.dataset_revision, + cache_path=this_output_path(), + window_size_bytes=50_000_000, + ), + ) + return step, data_config + + +def main(): + plantcad_tokenized, data_config = tokenize_plantcad() + + # Generate sweep steps + plantcad_sweep = generate_isoflop_sweeps( + tokenized=plantcad_tokenized, + vocab_size=plantcad_tokenized.config.vocab_size, + seq_len=data_config.seq_len, + total_token_count=data_config.total_token_count, + ) + + # Execute in batches of 16 sweep steps to avoid head node OOM errors. + # See: https://discord.com/channels/1354881461060243556/1442689455554171071/1447914001957785620 + # TODO: Figure out how to run on compute nodes instead w/o reverting back to this PR: + # https://discord.com/channels/1354881461060243556/1442689455554171071/1447920947402375291 + batch_size = 16 + batches = [plantcad_sweep[i : i + batch_size] for i in range(0, len(plantcad_sweep), batch_size)] + batch_index: int | None = int(os.environ["SWEEP_BATCH_INDEX"]) if "SWEEP_BATCH_INDEX" in os.environ else None + if batch_index is not None: + logger.info(f"SWEEP_BATCH_INDEX={batch_index}; running batch {batch_index + 1} of {len(batches)}") + batches = [batches[int(batch_index)]] + for i, batch in enumerate(batches): + logger.info(f"Running batch {i + 1}/{len(batches)} with {len(batch)} sweep steps") + if os.environ.get("DRY_RUN"): + logger.info("DRY RUN; skipping execution") + continue + executor_main(steps=[plantcad_tokenized, *batch]) + + +if __name__ == "__main__": + main() diff --git a/experiments/plantcad/exp2101_text_isoflop_sweep.py b/experiments/plantcad/exp2101_text_isoflop_sweep.py new file mode 100644 index 0000000000..79aac63597 --- /dev/null +++ b/experiments/plantcad/exp2101_text_isoflop_sweep.py @@ -0,0 +1,675 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate ISOFlop sweep steps for model sizes, architectures, and epochs on pre-tokenized text.""" + +import os +import math +import logging +import dataclasses +import numpy as np +import pandas as pd +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, replace + +from levanter.data.text import LMMixtureDatasetConfig, UrlDatasetSourceConfig +from levanter.models.llama import LlamaConfig +from levanter.models.qwen import Qwen3Config +from levanter.optim.cautious import CautiousConfig +from levanter.optim.config import OptimizerConfig +from levanter.utils.flop_utils import lm_flops_per_token +from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig + +from experiments.defaults import default_train +from experiments.evals.task_configs import EvalTaskConfig +from experiments.llama import llama3_tokenizer, llama3_tokenizer_vocab_size +from experiments.simple_train_config import SimpleTrainConfig +from marin.execution.executor import ExecutorStep, executor_main +from fray.cluster import ResourceConfig + +logger = logging.getLogger("ray") + +# TPU v5p hardware constants for memory estimation +# Constants for TPU v5p +HBM_PER_CHIP_GIB = 95 +CORES_PER_CHIP = 2 +V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512] # TPU slices + +ModelConfig = LlamaConfig | Qwen3Config + + +def simulated_epoch_train( + name: str, + tokenized: LMMixtureDatasetConfig, + model_config: ModelConfig, + train_config: "SimpleTrainConfig", + train_tokens: int, + dataset_tokens: int, + epoch_count: int = 1, + tags: Sequence[str] = (), + eval_harness_tasks: Sequence[EvalTaskConfig] = (), +) -> ExecutorStep: + """Train with simulated epoching. When epoch_count=1, uses full dataset.""" + if not isinstance(epoch_count, int) or epoch_count < 1: + raise ValueError(f"epoch_count must be int >= 1, got {epoch_count}") + + pretraining_data = tokenized + + if epoch_count == 1: + return default_train( + name, + tokenized=pretraining_data, + model_config=model_config, + train_config=train_config, + tags=tags, + eval_harness_tasks=eval_harness_tasks, + use_default_validation=False, + ) + + # To use simulated epoching in Levanter, we need to first address the fact that + # we are already limiting training to less than 1 true epoch in each run. + # + # The Levanter formula for this feature takes two arguments, experiment_budget and target_budget, + # and then uses this formula to determine how to slice each epoch: + # + # simulated_data_ratio = experiment_budget / target_budget + # simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio) + # sliced_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset) + # + # See: https://github.com/marin-community/marin/blob/eb4acbdd185a34202da16052c46c74eb570e69a5/lib/levanter/src/levanter/data/text.py#L1273-L1280 + # + # This means that `simulated_data_ratio` must become equal to `train_tokens / dataset_tokens / epoch_count` + # in order for the simulated epochs to work on top of a partial epoch. + # We accomplish this here by setting: + # - experiment_budget = train_tokens + # - target_budget = dataset_tokens * epoch_count + experiment_budget, target_budget = train_tokens, dataset_tokens * epoch_count + simulated_pretraining_data = dataclasses.replace( + pretraining_data, target_budget=target_budget, experiment_budget=experiment_budget + ) + + return default_train( + name, + tokenized=simulated_pretraining_data, + model_config=model_config, + train_config=train_config, + tags=tags, + eval_harness_tasks=eval_harness_tasks, + use_default_validation=False, + ) + + +def format_num(n: int | float) -> str: + """Format numbers in T/B/M/K notation (e.g., 1.5T, 100B, 5.2M, 1.0K).""" + if n >= 1_000_000_000_000: + return f"{n / 1_000_000_000_000:.1f}T" + elif n >= 1_000_000_000: + return f"{n / 1_000_000_000:.1f}B" + elif n >= 10_000_000: + return f"{int(n / 1_000_000)}M" + elif n >= 1_000_000: + return f"{n / 1_000_000:.1f}M" + elif n >= 1_000: + return f"{n / 1_000:.1f}K" + return str(int(n)) + + +@dataclass(frozen=True) +class IsoFlopDataConfig: + # Use pretokenized DCLM baseline from GCS (control experiment) + # Absolute path in same region used to avoid re-downloading when running with a different MARIN_PREFIX + # (i.e. marin-dna-us-central1) + tokenized_path: str = "gs://marin-us-central1/tokenized/dclm_baseline-0206f1" + tokenizer: str = llama3_tokenizer + vocab_size: int = llama3_tokenizer_vocab_size + seq_len: int = 4096 + # Keep consistent with PlantCAD total token count even though DCLM has far, far more + total_token_count: int = 2_600_000_000_000 # 2.6T + + +@dataclass(frozen=True) +class IsoFlopSweepConfig: + tokenized_dataset: LMMixtureDatasetConfig + vocab_size: int + seq_len: int + total_token_count: int + experiment_name: str = "plantcad_isoflop_v2.9" + budgets: list[float] = dataclasses.field( + # default_factory=lambda: list(np.logspace(np.log10(6.4e17), np.log10(8e18), 5)) + default_factory=lambda: list( + v for i, v in enumerate(np.logspace(np.log10(6.4e17), np.log10(8e18), 5)) if i in [0, 1, 4] + ) + ) + epochs: list[int] = dataclasses.field(default_factory=lambda: [1]) + steps_per_run: int = 65_536 + min_hidden_pow: int = 9 + max_hidden_pow: int = 12 + hidden_step_size: int = 128 + mlp_ratio: int = 4 + base_hidden_layer_ratio: int = 64 + hidden_head_ratio: int = 128 + lr_max: float | None = 0.03 + flop_tolerance: float = 0.01 + architectures: list[str] = dataclasses.field(default_factory=lambda: ["qwen"]) + # TODO: adjust eval example count to account for num tpus in the even that v5p-8 is not used + per_device_eval_parallelism: int = 8 + max_eval_batches: int = 1024 + num_evals: int = 3 + lr_constant: float = 0.33 + base_optimizer_config: OptimizerConfig = dataclasses.field( + default_factory=lambda: CautiousConfig( + learning_rate=1.0, + weight_decay=0.1, + min_lr_ratio=0.0, + warmup=0.1, + beta1=0.95, + beta2=0.98, + epsilon=1e-15, + max_grad_norm=1, + adamc_weight_decay=True, + lr_schedule="linear", + decay=0.2, + ), + ) + base_train_config: SimpleTrainConfig = dataclasses.field( + default_factory=lambda: SimpleTrainConfig( + resources=ResourceConfig.with_tpu("v5p-8"), + train_batch_size=1, + num_train_steps=50_000, + learning_rate=1.0, + weight_decay=0.1, + min_lr_ratio=0.0, + lr_schedule="linear", + decay=0.2, + ) + ) + + +def estimate_bytes( + param_count: int, + hidden_dim: int, + num_layers: int, + batch: int, + seq_len: int, + vocab: int, + optim_mult: int = 3, + dtype_size: int = 4, + # TODO: determine why this needs to be higher than the original 2 to avoid OOMs on DCLM + fudge_factor: float = 4, +) -> int: + """Estimate memory usage (in bytes) for one training step.""" + param_bytes = param_count * optim_mult * dtype_size + + act_bytes = (batch * seq_len) * ((hidden_dim * num_layers) + vocab * fudge_factor) + + total_bytes = param_bytes + act_bytes + return int(total_bytes) * fudge_factor + + +def pick_v5p_type( + config: ModelConfig, + hidden: int, + layers: int, + batch: int, + seq_len: int, + vocab: int, +) -> str: + """Select the smallest TPU v5p slice that fits the model in float32.""" + _, _, param_count = compute_param_count(config, vocab) + need_bytes = estimate_bytes(param_count, hidden, layers, batch, seq_len, vocab) + chip_bytes = HBM_PER_CHIP_GIB * 1024**3 + chips = math.ceil(need_bytes / chip_bytes) + cores_req = chips * CORES_PER_CHIP + + valid = [c for c in V5P_CORE_OPTIONS if c >= cores_req] + if not valid: + raise ValueError(f"Model too large for available v5p slices (need {cores_req} cores).") + + return f"v5p-{min(valid)}" + + +def round_to_power_of_two(x: float) -> int: + """Round ``x`` to the nearest power of two.""" + if x <= 1: + return 1 + return 2 ** math.ceil(math.log2(x)) + + +def compute_param_count(config: LlamaConfig, vocab_size: int) -> tuple[int, int, int]: + # Copied from compute_num_parameters in experiments/llama.py and modified to + # return multiple results; see: + # https://github.com/marin-community/marin/blob/bc58ab8ee62ba5e38ce4f1e2d7d64271431be160/experiments/llama.py#L249-L267 + head_size = config.hidden_dim // config.num_heads + q_params = config.num_heads * head_size * config.hidden_dim + k_params = config.num_kv_heads * head_size * config.hidden_dim + v_params = config.num_kv_heads * head_size * config.hidden_dim + o_params = config.num_heads * head_size * config.hidden_dim + attention_params = q_params + k_params + v_params + o_params + + layer_norm_params = 2 * config.hidden_dim + + gate_params = config.hidden_dim * config.intermediate_dim + up_params = config.hidden_dim * config.intermediate_dim + down_params = config.intermediate_dim * config.hidden_dim + mlp_params = gate_params + up_params + down_params + + nonembedding_params = config.num_layers * (attention_params + mlp_params + layer_norm_params) + embedding_params = 2 * vocab_size * config.hidden_dim + + return embedding_params, nonembedding_params, nonembedding_params + embedding_params + + +def compute_total_flops( + batch: int, + num_layers: int, + hidden: int, + intermediate: int, + num_kv_heads: int, + num_heads: int, + steps: int, + seq_len: int, + vocab_size: int, +) -> float: + """Compute total training FLOPs using Levanter utilities.""" + + flops_per_token = lm_flops_per_token( + hidden, + intermediate, + num_layers, + num_kv_heads, + num_heads, + seq_len, + vocab_size, + glu=True, + ) + return flops_per_token * batch * steps * seq_len + + +@dataclass +class IsoFlopRunConfig: + architecture: str + hidden_size: int + intermediate_dim: int + num_layers: int + n_heads: int + n_kv_heads: int + batch_size: int + batch_target: int # Original batch size before LR-based halving + train_steps: int + lr: float + beta2: float + budget: float + steps_per_eval: int + train_tokens: int + dataset_tokens: int + num_params: int + embed_params: int + epoch_count: int + model_config: ModelConfig + + +def generate_run_configs(cfg: IsoFlopSweepConfig, budget: float) -> Iterator[IsoFlopRunConfig]: + """Generate ISOFlop run configurations within the FLOP budget.""" + + dataset_tokens = cfg.total_token_count + + # Loop over architecture as the primary dimension of the search space + for architecture in cfg.architectures: + # Loop through hidden size on a grid, which will determine the model + # size and therefore token count for each run config + for hidden_size in range(2**cfg.min_hidden_pow, (2**cfg.max_hidden_pow) + 1, cfg.hidden_step_size): + hs_pow = math.log2(hidden_size) + intermediate_dim = hidden_size * cfg.mlp_ratio + num_layers = round(hidden_size / (cfg.base_hidden_layer_ratio + (hs_pow * 4) - cfg.min_hidden_pow)) + assert ( + hidden_size % cfg.hidden_head_ratio == 0 + ), f"hidden_size ({hidden_size}) must be evenly divisible by hidden_head_ratio ({cfg.hidden_head_ratio})" + n_heads = max(1, hidden_size // cfg.hidden_head_ratio) + n_kv_heads = n_heads + + # Calculate batch size to meet budget with fixed steps + batch_exact_val = budget / compute_total_flops( + batch=1, + num_layers=num_layers, + hidden=hidden_size, + intermediate=intermediate_dim, + num_kv_heads=n_kv_heads, + num_heads=n_heads, + steps=cfg.steps_per_run, + seq_len=cfg.seq_len, + vocab_size=cfg.vocab_size, + ) + + batch_size = round_to_power_of_two(batch_exact_val) + batch_target = batch_size # Store original before LR-based halving + + # Scale LR with sqrt(batch) and hidden size + # Reference: https://arxiv.org/pdf/2203.03466 (Section 10 Related Works) + lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size + + # Halve batch size until LR is stable + if cfg.lr_max is not None: + while lr > cfg.lr_max: + logger.warning( + f"Halving batch size for ({architecture=}, {hidden_size=}): " + f"{batch_size} -> {batch_size // 2} (lr={lr:.4f}, lr_max={cfg.lr_max})" + ) + batch_size //= 2 + lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size + + # Set beta2 based on https://arxiv.org/pdf/2507.07101 + b2 = 0.98 ** (batch_size / 128) + + if batch_size < 8: + logger.warning( + f"Skipping config for ({architecture=}, {hidden_size=}) " + f"with batch size {batch_size} (less than 8)" + ) + continue + + # Recompute exact steps based on adjusted batch size + steps_exact = budget / compute_total_flops( + batch=batch_size, + num_layers=num_layers, + hidden=hidden_size, + intermediate=intermediate_dim, + num_kv_heads=n_kv_heads, + num_heads=n_heads, + steps=1, + seq_len=cfg.seq_len, + vocab_size=cfg.vocab_size, + ) + train_steps = round(steps_exact) + + # Ensure actual flops still within range + achieved_flops = compute_total_flops( + batch=batch_size, + num_layers=num_layers, + hidden=hidden_size, + intermediate=intermediate_dim, + num_kv_heads=n_kv_heads, + num_heads=n_heads, + steps=train_steps, + seq_len=cfg.seq_len, + vocab_size=cfg.vocab_size, + ) + if abs(achieved_flops - budget) / budget > cfg.flop_tolerance: + logger.warning( + f"Skipping config for ({architecture=}, {hidden_size=}) with achieved flops {achieved_flops} " + f"(not within {cfg.flop_tolerance} of budget {budget})" + ) + continue + + train_tokens = train_steps * batch_size * cfg.seq_len + # Subtract 1 from num_evals to account for the first evaluation + num_evals = max(1, cfg.num_evals - 1) + steps_per_eval = max(1, train_steps // num_evals) + + if train_tokens > dataset_tokens: + logger.warning( + f"Skipping config for ({architecture=}, {hidden_size=}) with train tokens {train_tokens} " + f"(greater than dataset tokens {dataset_tokens})" + ) + continue + + if architecture == "llama": + model_cfg = LlamaConfig( + max_seq_len=cfg.seq_len, + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_heads=n_heads, + num_kv_heads=n_kv_heads, + num_layers=num_layers, + ) + elif architecture == "qwen": + model_cfg = Qwen3Config( + max_seq_len=cfg.seq_len, + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_heads=n_heads, + num_kv_heads=n_kv_heads, + num_layers=num_layers, + rope=Llama3RotaryEmbeddingsConfig(), + ) + else: + raise ValueError(f"Unknown architecture: {architecture}") + + # num_params = compute_num_parameters(model_cfg, cfg.vocab_size) + embed_params, _, num_params = compute_param_count(model_cfg, cfg.vocab_size) + + for epoch_count in cfg.epochs: + yield IsoFlopRunConfig( + architecture=architecture, + hidden_size=hidden_size, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + batch_size=batch_size, + batch_target=batch_target, + train_steps=train_steps, + lr=lr, + beta2=b2, + budget=budget, + steps_per_eval=steps_per_eval, + train_tokens=train_tokens, + dataset_tokens=dataset_tokens, + num_params=num_params, + embed_params=embed_params, + epoch_count=epoch_count, + model_config=model_cfg, + ) + + +def _log_isoflop_run_configs(all_configs: list[IsoFlopRunConfig]): + """Log summary of generated ISOFlop configurations.""" + if all_configs: + df = pd.DataFrame([dataclasses.asdict(c) for c in all_configs]) + + # Format large numbers for readability + if "num_params" in df.columns: + df["num_params"] = df["num_params"].apply(format_num) + if "embed_params" in df.columns: + df["embed_params"] = df["embed_params"].apply(format_num) + if "train_tokens" in df.columns: + df["train_tokens"] = df["train_tokens"].apply(format_num) + + pd.set_option("display.max_rows", None) + pd.set_option("display.max_columns", None) + pd.set_option("display.width", 1000) + + logger.info("\n" + "=" * 80) + logger.info("Configuration Summary Dataframe") + logger.info("=" * 80) + logger.info("\n" + str(df.drop(columns=["model_config"]))) + + logger.info("\n" + "=" * 50) + logger.info("Configs per Budget") + logger.info("=" * 50) + logger.info("\n" + str(df.groupby("budget").size())) + + logger.info("\n" + "=" * 50) + logger.info("Configs per Architecture") + logger.info("=" * 50) + logger.info("\n" + str(df.groupby("architecture").size())) + + # Create table of unique param count mappings + logger.info("\n" + "=" * 50) + logger.info("Param Count Mapping") + logger.info("=" * 50) + raw_df = pd.DataFrame([dataclasses.asdict(c) for c in all_configs]) + param_mapping = raw_df[["num_params", "embed_params"]].drop_duplicates() + param_mapping["nonembedding_params"] = param_mapping["num_params"] - param_mapping["embed_params"] + param_mapping["pct_embedding_params"] = ( + param_mapping["embed_params"] / param_mapping["num_params"] * 100 + ).round(2) + param_mapping = param_mapping[ + ["num_params", "embed_params", "nonembedding_params", "pct_embedding_params"] + ].sort_values("num_params") + param_mapping = param_mapping.rename( + columns={ + "num_params": "Total Params", + "embed_params": "Embedding Params", + "nonembedding_params": "Non-Embedding Params", + "pct_embedding_params": "% Embedding", + } + ) + logger.info("\n" + param_mapping.to_string(index=False)) + + logger.info("=" * 50 + "\n") + else: + logger.warning("No configurations generated!") + + +def generate_isoflop_steps(config: IsoFlopSweepConfig) -> list[ExecutorStep]: + """Generate executor steps for an ISOFlop sweep.""" + + # Collect all run configs first + all_configs: list[IsoFlopRunConfig] = [] + + for budget in config.budgets: + for c in generate_run_configs(config, budget): + all_configs.append(c) + + _log_isoflop_run_configs(all_configs) + + # Generate executor steps from validated configs + steps: list[ExecutorStep] = [] + for c in all_configs: + # Use the pre-computed model config + model_cfg = c.model_config + + tpu_type = pick_v5p_type( + config=model_cfg, + hidden=c.hidden_size, + layers=c.num_layers, + batch=c.batch_size, + seq_len=config.seq_len, + vocab=config.vocab_size, + ) + optimizer_cfg = replace(config.base_optimizer_config, learning_rate=c.lr, beta2=c.beta2) + train_cfg = replace( + config.base_train_config, + train_batch_size=c.batch_size, + learning_rate=c.lr, + num_train_steps=c.train_steps, + steps_per_eval=c.steps_per_eval, + per_device_eval_parallelism=config.per_device_eval_parallelism, + max_eval_batches=config.max_eval_batches, + resources=ResourceConfig.with_tpu(tpu_type), + optimizer_config=optimizer_cfg, + ) + + param_count = c.num_params + step = simulated_epoch_train( + name="-".join( + [ + config.experiment_name, + f"A_{c.architecture}", + f"F{c.budget:.1e}", + f"P{format_num(param_count)}", + f"T{format_num(c.train_tokens)}", + f"E{c.epoch_count}", + ] + ), + tokenized=config.tokenized_dataset, + model_config=model_cfg, + train_config=train_cfg, + train_tokens=c.train_tokens, + dataset_tokens=c.dataset_tokens, + epoch_count=c.epoch_count, + eval_harness_tasks=[], + tags=( + f"architecture={c.architecture}", + f"flops_budget={c.budget:.1e}", + f"hidden_size={c.hidden_size}", + f"num_layers={c.num_layers}", + f"batch_size={c.batch_size}", + f"steps={c.train_steps}", + f"tpu={tpu_type}", + f"params={param_count}", + f"params_embed={c.embed_params}", + f"params_nonembed={c.num_params - c.embed_params}", + f"tokens={c.train_tokens}", + f"epochs={c.epoch_count}", + ), + ) + steps.append(step) + + return steps + + +def generate_isoflop_sweep( + tokenized: LMMixtureDatasetConfig, + **kwargs, +) -> list[ExecutorStep]: + sweep_cfg = IsoFlopSweepConfig(tokenized_dataset=tokenized, **kwargs) + steps = generate_isoflop_steps(sweep_cfg) + + return steps + + +def get_data_config() -> tuple[LMMixtureDatasetConfig, IsoFlopDataConfig]: + """Use pretokenized DCLM baseline dataset from GCS. + + Returns: + A tuple of (LMMixtureDatasetConfig, IsoFlopDataConfig with dataset metadata). + """ + data_config = IsoFlopDataConfig() + # Create LMMixtureDatasetConfig directly with absolute GCS path (no download needed) + mixture_config = LMMixtureDatasetConfig( + configs={ + "dclm_baseline": UrlDatasetSourceConfig( + cache_dir=data_config.tokenized_path, + ) + }, + train_weights={"dclm_baseline": 1.0}, + tokenizer=data_config.tokenizer, + # TODO: reduce this after ruling out eval noise (back to ~1024) + num_validation_sequences={"dclm_baseline": 131_072}, + ) + return mixture_config, data_config + + +def main(): + mixture_config, data_config = get_data_config() + + # Generate sweep steps + plantcad_sweep = generate_isoflop_sweep( + tokenized=mixture_config, + vocab_size=data_config.vocab_size, + seq_len=data_config.seq_len, + total_token_count=data_config.total_token_count, + ) + + # Execute in batches of 16 sweep steps to avoid head node OOM errors. + # See: https://discord.com/channels/1354881461060243556/1442689455554171071/1447914001957785620 + # TODO: Figure out how to run on compute nodes instead w/o reverting back to this PR: + # https://discord.com/channels/1354881461060243556/1442689455554171071/1447920947402375291 + batch_size = 16 + batches = [plantcad_sweep[i : i + batch_size] for i in range(0, len(plantcad_sweep), batch_size)] + batch_index: int | None = int(os.environ["SWEEP_BATCH_INDEX"]) if "SWEEP_BATCH_INDEX" in os.environ else None + if batch_index is not None: + logger.info(f"SWEEP_BATCH_INDEX={batch_index}; running batch {batch_index + 1} of {len(batches)}") + batches = [batches[int(batch_index)]] + for i, batch in enumerate(batches): + logger.info(f"Running batch {i + 1}/{len(batches)} with {len(batch)} sweep steps") + if os.environ.get("DRY_RUN"): + logger.info("DRY RUN; skipping execution") + continue + executor_main(steps=batch) + + +if __name__ == "__main__": + main() diff --git a/experiments/plantcad/exp2101_text_multi_isoflop_sweep.py b/experiments/plantcad/exp2101_text_multi_isoflop_sweep.py new file mode 100644 index 0000000000..6050a02f37 --- /dev/null +++ b/experiments/plantcad/exp2101_text_multi_isoflop_sweep.py @@ -0,0 +1,795 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate ISOFlop sweep steps for model sizes, architectures, and epochs on pre-tokenized text.""" + +import os +import math +import logging +import dataclasses +import numpy as np +import pandas as pd +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, replace + +from levanter.data.text import LMMixtureDatasetConfig, UrlDatasetSourceConfig +from levanter.models.llama import LlamaConfig +from levanter.models.qwen import Qwen3Config +from levanter.optim.cautious import CautiousConfig +from levanter.optim.config import OptimizerConfig +from levanter.utils.flop_utils import lm_flops_per_token +from levanter.layers.rotary import Llama3RotaryEmbeddingsConfig + +from experiments.defaults import default_train +from experiments.evals.task_configs import EvalTaskConfig +from experiments.llama import llama3_tokenizer, llama3_tokenizer_vocab_size +from experiments.simple_train_config import SimpleTrainConfig +from marin.execution.executor import ExecutorStep, executor_main +from fray.cluster import ResourceConfig + +logger = logging.getLogger("ray") + +# TPU v5p hardware constants for memory estimation +# Constants for TPU v5p +HBM_PER_CHIP_GIB = 95 +CORES_PER_CHIP = 2 +V5P_CORE_OPTIONS = [8, 16, 32, 128, 256, 512] # TPU slices + +ModelConfig = LlamaConfig | Qwen3Config + +# Maximum number of configs per (budget, architecture) combination before downsampling +MAX_SWEEP_CONFIGS = 10 + + +def simulated_epoch_train( + name: str, + tokenized: LMMixtureDatasetConfig, + model_config: ModelConfig, + train_config: "SimpleTrainConfig", + train_tokens: int, + dataset_tokens: int, + epoch_count: int = 1, + tags: Sequence[str] = (), + eval_harness_tasks: Sequence[EvalTaskConfig] = (), +) -> ExecutorStep: + """Train with simulated epoching. When epoch_count=1, uses full dataset.""" + if not isinstance(epoch_count, int) or epoch_count < 1: + raise ValueError(f"epoch_count must be int >= 1, got {epoch_count}") + + pretraining_data = tokenized + + if epoch_count == 1: + return default_train( + name, + tokenized=pretraining_data, + model_config=model_config, + train_config=train_config, + tags=tags, + eval_harness_tasks=eval_harness_tasks, + use_default_validation=False, + ) + + # To use simulated epoching in Levanter, we need to first address the fact that + # we are already limiting training to less than 1 true epoch in each run. + # + # The Levanter formula for this feature takes two arguments, experiment_budget and target_budget, + # and then uses this formula to determine how to slice each epoch: + # + # simulated_data_ratio = experiment_budget / target_budget + # simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio) + # sliced_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset) + # + # See: https://github.com/marin-community/marin/blob/eb4acbdd185a34202da16052c46c74eb570e69a5/lib/levanter/src/levanter/data/text.py#L1273-L1280 + # + # This means that `simulated_data_ratio` must become equal to `train_tokens / dataset_tokens / epoch_count` + # in order for the simulated epochs to work on top of a partial epoch. + # We accomplish this here by setting: + # - experiment_budget = train_tokens + # - target_budget = dataset_tokens * epoch_count + experiment_budget, target_budget = train_tokens, dataset_tokens * epoch_count + simulated_pretraining_data = dataclasses.replace( + pretraining_data, target_budget=target_budget, experiment_budget=experiment_budget + ) + + return default_train( + name, + tokenized=simulated_pretraining_data, + model_config=model_config, + train_config=train_config, + tags=tags, + eval_harness_tasks=eval_harness_tasks, + use_default_validation=False, + ) + + +def format_num(n: int | float) -> str: + """Format numbers in T/B/M/K notation (e.g., 1.5T, 100B, 5.2M, 1.0K).""" + if n >= 1_000_000_000_000: + return f"{n / 1_000_000_000_000:.1f}T" + elif n >= 1_000_000_000: + return f"{n / 1_000_000_000:.1f}B" + elif n >= 10_000_000: + return f"{int(n / 1_000_000)}M" + elif n >= 1_000_000: + return f"{n / 1_000_000:.1f}M" + elif n >= 1_000: + return f"{n / 1_000:.1f}K" + return str(int(n)) + + +@dataclass(frozen=True) +class IsoFlopDataConfig: + # Use pretokenized DCLM baseline from GCS (control experiment) + # Absolute path in same region used to avoid re-downloading when running with a different MARIN_PREFIX + # (i.e. marin-dna-us-central1) + tokenized_path: str = "gs://marin-us-central1/tokenized/dclm_baseline-0206f1" + tokenizer: str = llama3_tokenizer + vocab_size: int = llama3_tokenizer_vocab_size + seq_len: int = 4096 + # Keep consistent with PlantCAD total token count even though DCLM has far, far more + total_token_count: int = 2_600_000_000_000 # 2.6T + + +@dataclass(frozen=True) +class IsoFlopSweepParams: + """Configuration for a specific compute range in the ISOFlop sweep.""" + + experiment_name: str + compute_range_name: str + budgets: list[float] + steps_per_run: int + hidden_step_size: int = 128 + hidden_head_ratio: int = 128 + + +# Predefined compute range configurations +ISOFLOP_SWEEPS = { + # low-compute: 1.6e17 to 2e18, steps=16_384 (v2.12) + # "low": IsoFlopSweepParams( + # experiment_name="plantcad_isoflop_v2.12", + # compute_range_name="low", + # budgets=list(np.logspace(np.log10(1.6e17), np.log10(2e18), 5)), + # steps_per_run=16_384, + # hidden_step_size=128, + # ), + # # mid-compute: 3.2e17 to 4e18, steps=32_768 (v2.13) + "mid": IsoFlopSweepParams( + experiment_name="plantcad_isoflop_v2.13", + compute_range_name="mid", + budgets=list(np.logspace(np.log10(3.2e17), np.log10(4e18), 5)), + steps_per_run=32_768, + ), + # # high-compute: 6.4e17 to 8e18, steps=65_536 (v2.9) + # "high": IsoFlopSweepParams( + # experiment_name="plantcad_isoflop_v2.9", + # compute_range_name="high", + # # i in [0, 1, 4] + # budgets=list(np.logspace(np.log10(6.4e17), np.log10(8e18), 5)), + # steps_per_run=65_536, + # ), +} + + +@dataclass(frozen=True) +class IsoFlopSweepConfig: + tokenized_dataset: LMMixtureDatasetConfig + vocab_size: int + seq_len: int + total_token_count: int + experiment_name: str + compute_range_name: str + budgets: list[float] + steps_per_run: int + hidden_step_size: int + hidden_head_ratio: int + + epochs: list[int] = dataclasses.field(default_factory=lambda: [1]) + min_hidden_pow: int = 9 + max_hidden_pow: int = 12 + mlp_ratio: int = 4 + base_hidden_layer_ratio: int = 64 + lr_max: float | None = 0.03 + flop_tolerance: float = 0.01 + architectures: list[str] = dataclasses.field(default_factory=lambda: ["qwen"]) + # TODO: adjust eval example count to account for num tpus in the even that v5p-8 is not used + per_device_eval_parallelism: int = 8 + max_eval_batches: int = 1024 + num_evals: int = 3 + lr_constant: float = 0.33 + base_optimizer_config: OptimizerConfig = dataclasses.field( + default_factory=lambda: CautiousConfig( + learning_rate=1.0, + weight_decay=0.1, + min_lr_ratio=0.0, + warmup=0.1, + beta1=0.95, + beta2=0.98, + epsilon=1e-15, + max_grad_norm=1, + adamc_weight_decay=True, + lr_schedule="linear", + decay=0.2, + ), + ) + base_train_config: SimpleTrainConfig = dataclasses.field( + default_factory=lambda: SimpleTrainConfig( + resources=ResourceConfig.with_tpu("v5p-8"), + train_batch_size=1, + num_train_steps=50_000, + learning_rate=1.0, + weight_decay=0.1, + min_lr_ratio=0.0, + lr_schedule="linear", + decay=0.2, + ) + ) + + +def estimate_bytes( + param_count: int, + hidden_dim: int, + num_layers: int, + batch: int, + seq_len: int, + vocab: int, + optim_mult: int = 3, + dtype_size: int = 4, + # TODO: determine why this needs to be higher than the original 2 to avoid OOMs on DCLM + fudge_factor: float = 4, +) -> int: + """Estimate memory usage (in bytes) for one training step.""" + param_bytes = param_count * optim_mult * dtype_size + + act_bytes = (batch * seq_len) * ((hidden_dim * num_layers) + vocab * fudge_factor) + + total_bytes = param_bytes + act_bytes + return int(total_bytes) * fudge_factor + + +def pick_v5p_type( + config: ModelConfig, + hidden: int, + layers: int, + batch: int, + seq_len: int, + vocab: int, +) -> str: + """Select the smallest TPU v5p slice that fits the model in float32.""" + _, _, param_count = compute_param_count(config, vocab) + need_bytes = estimate_bytes(param_count, hidden, layers, batch, seq_len, vocab) + chip_bytes = HBM_PER_CHIP_GIB * 1024**3 + chips = math.ceil(need_bytes / chip_bytes) + cores_req = chips * CORES_PER_CHIP + + valid = [c for c in V5P_CORE_OPTIONS if c >= cores_req] + if not valid: + raise ValueError(f"Model too large for available v5p slices (need {cores_req} cores).") + + return f"v5p-{min(valid)}" + + +def round_to_power_of_two(x: float) -> int: + """Round ``x`` to the nearest power of two.""" + if x <= 1: + return 1 + return 2 ** math.ceil(math.log2(x)) + + +def compute_param_count(config: LlamaConfig, vocab_size: int) -> tuple[int, int, int]: + # Copied from compute_num_parameters in experiments/llama.py and modified to + # return multiple results; see: + # https://github.com/marin-community/marin/blob/bc58ab8ee62ba5e38ce4f1e2d7d64271431be160/experiments/llama.py#L249-L267 + head_size = config.hidden_dim // config.num_heads + q_params = config.num_heads * head_size * config.hidden_dim + k_params = config.num_kv_heads * head_size * config.hidden_dim + v_params = config.num_kv_heads * head_size * config.hidden_dim + o_params = config.num_heads * head_size * config.hidden_dim + attention_params = q_params + k_params + v_params + o_params + + layer_norm_params = 2 * config.hidden_dim + + gate_params = config.hidden_dim * config.intermediate_dim + up_params = config.hidden_dim * config.intermediate_dim + down_params = config.intermediate_dim * config.hidden_dim + mlp_params = gate_params + up_params + down_params + + nonembedding_params = config.num_layers * (attention_params + mlp_params + layer_norm_params) + embedding_params = 2 * vocab_size * config.hidden_dim + + return embedding_params, nonembedding_params, nonembedding_params + embedding_params + + +def compute_total_flops( + batch: int, + num_layers: int, + hidden: int, + intermediate: int, + num_kv_heads: int, + num_heads: int, + steps: int, + seq_len: int, + vocab_size: int, +) -> float: + """Compute total training FLOPs using Levanter utilities.""" + + flops_per_token = lm_flops_per_token( + hidden, + intermediate, + num_layers, + num_kv_heads, + num_heads, + seq_len, + vocab_size, + glu=True, + ) + return flops_per_token * batch * steps * seq_len + + +@dataclass +class IsoFlopRunConfig: + experiment_name: str + compute_range_name: str + steps_per_run: int + hidden_step_size: int + architecture: str + hidden_size: int + intermediate_dim: int + num_layers: int + n_heads: int + n_kv_heads: int + batch_size: int + batch_target: int # Original batch size before LR-based halving + train_steps: int + lr: float + beta2: float + budget: float + steps_per_eval: int + train_tokens: int + dataset_tokens: int + num_params: int + embed_params: int + epoch_count: int + model_config: ModelConfig + + +def generate_run_configs(cfg: IsoFlopSweepConfig, budget: float) -> Iterator[IsoFlopRunConfig]: + """Generate ISOFlop run configurations within the FLOP budget.""" + + dataset_tokens = cfg.total_token_count + + # Loop over architecture as the primary dimension of the search space + for architecture in cfg.architectures: + + # Loop through hidden size on a grid, which will determine the model + # size and therefore token count for each run config + for hidden_size in range(2**cfg.min_hidden_pow, (2**cfg.max_hidden_pow) + 1, cfg.hidden_step_size): + hs_pow = math.log2(hidden_size) + intermediate_dim = hidden_size * cfg.mlp_ratio + num_layers = round(hidden_size / (cfg.base_hidden_layer_ratio + (hs_pow * 4) - cfg.min_hidden_pow)) + assert ( + hidden_size % cfg.hidden_head_ratio == 0 + ), f"hidden_size ({hidden_size}) must be evenly divisible by hidden_head_ratio ({cfg.hidden_head_ratio})" + n_heads = max(1, hidden_size // cfg.hidden_head_ratio) + n_kv_heads = n_heads + + # Calculate batch size to meet budget with fixed steps + batch_exact_val = budget / compute_total_flops( + batch=1, + num_layers=num_layers, + hidden=hidden_size, + intermediate=intermediate_dim, + num_kv_heads=n_kv_heads, + num_heads=n_heads, + steps=cfg.steps_per_run, + seq_len=cfg.seq_len, + vocab_size=cfg.vocab_size, + ) + + batch_size = round_to_power_of_two(batch_exact_val) + batch_target = batch_size # Store original before LR-based halving + + # Scale LR with sqrt(batch) and hidden size + # Reference: https://arxiv.org/pdf/2203.03466 (Section 10 Related Works) + lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size + + # Halve batch size until LR is stable + if cfg.lr_max is not None: + while lr > cfg.lr_max: + logger.warning( + f"Halving batch size for ({architecture=}, {hidden_size=}): " + f"{batch_size} -> {batch_size // 2} (lr={lr:.4f}, lr_max={cfg.lr_max})" + ) + batch_size //= 2 + lr = (cfg.lr_constant * math.sqrt(batch_size)) / hidden_size + + # Set beta2 based on https://arxiv.org/pdf/2507.07101 + b2 = 0.98 ** (batch_size / 128) + + if batch_size < 8: + logger.warning( + f"Skipping config for ({architecture=}, {hidden_size=}) " + f"with batch size {batch_size} (less than 8)" + ) + continue + + # Recompute exact steps based on adjusted batch size + steps_exact = budget / compute_total_flops( + batch=batch_size, + num_layers=num_layers, + hidden=hidden_size, + intermediate=intermediate_dim, + num_kv_heads=n_kv_heads, + num_heads=n_heads, + steps=1, + seq_len=cfg.seq_len, + vocab_size=cfg.vocab_size, + ) + train_steps = round(steps_exact) + + # Ensure actual flops still within range + achieved_flops = compute_total_flops( + batch=batch_size, + num_layers=num_layers, + hidden=hidden_size, + intermediate=intermediate_dim, + num_kv_heads=n_kv_heads, + num_heads=n_heads, + steps=train_steps, + seq_len=cfg.seq_len, + vocab_size=cfg.vocab_size, + ) + if abs(achieved_flops - budget) / budget > cfg.flop_tolerance: + logger.warning( + f"Skipping config for ({architecture=}, {hidden_size=}) with achieved flops {achieved_flops} " + f"(not within {cfg.flop_tolerance} of budget {budget})" + ) + continue + + train_tokens = train_steps * batch_size * cfg.seq_len + # Subtract 1 from num_evals to account for the first evaluation + num_evals = max(1, cfg.num_evals - 1) + steps_per_eval = max(1, train_steps // num_evals) + + if train_tokens > dataset_tokens: + logger.warning( + f"Skipping config for ({architecture=}, {hidden_size=}) with train tokens {train_tokens} " + f"(greater than dataset tokens {dataset_tokens})" + ) + continue + + if architecture == "llama": + model_cfg = LlamaConfig( + max_seq_len=cfg.seq_len, + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_heads=n_heads, + num_kv_heads=n_kv_heads, + num_layers=num_layers, + ) + elif architecture == "qwen": + model_cfg = Qwen3Config( + max_seq_len=cfg.seq_len, + hidden_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_heads=n_heads, + num_kv_heads=n_kv_heads, + num_layers=num_layers, + rope=Llama3RotaryEmbeddingsConfig(), + ) + else: + raise ValueError(f"Unknown architecture: {architecture}") + + embed_params, _, num_params = compute_param_count(model_cfg, cfg.vocab_size) + + for epoch_count in cfg.epochs: + yield IsoFlopRunConfig( + experiment_name=cfg.experiment_name, + compute_range_name=cfg.compute_range_name, + steps_per_run=cfg.steps_per_run, + hidden_step_size=cfg.hidden_step_size, + architecture=architecture, + hidden_size=hidden_size, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + batch_size=batch_size, + batch_target=batch_target, + train_steps=train_steps, + lr=lr, + beta2=b2, + budget=budget, + steps_per_eval=steps_per_eval, + train_tokens=train_tokens, + dataset_tokens=dataset_tokens, + num_params=num_params, + embed_params=embed_params, + epoch_count=epoch_count, + model_config=model_cfg, + ) + + +def downsample_configs(configs: list[IsoFlopRunConfig], max_configs: int = MAX_SWEEP_CONFIGS) -> list[IsoFlopRunConfig]: + """Downsample configs to at most max_configs, keeping first and last, evenly sampling middle. + + Configs are assumed to be sorted by hidden_size (ascending). The downsampling preserves + the first and last configs (smallest and largest hidden_size) and takes every Nth config + from the middle, where N is chosen to keep the total count <= max_configs. + + Args: + configs: List of IsoFlopRunConfig (assumed sorted by hidden_size ascending) + max_configs: Maximum number of configs to return + + Returns: + Downsampled list with at most max_configs configs, preserving first and last + """ + if len(configs) <= max_configs: + return configs + + if max_configs < 2: + raise ValueError("max_configs must be at least 2 to keep first and last") + + # Keep first and last, downsample middle + first = configs[0] + last = configs[-1] + middle = configs[1:-1] + + # Calculate step size N to get at most (max_configs - 2) middle elements + # N is the minimum step that keeps len(downsampled_middle) <= max_middle + max_middle = max_configs - 2 + n = math.ceil(len(middle) / max_middle) + + # Take every Nth element from middle + downsampled_middle = middle[::n] + + return [first, *downsampled_middle, last] + + +def _log_isoflop_run_configs(all_configs: list[IsoFlopRunConfig]): + """Log summary of generated ISOFlop configurations.""" + if all_configs: + df = pd.DataFrame([dataclasses.asdict(c) for c in all_configs]) + + # Format large numbers for readability + if "num_params" in df.columns: + df["num_params"] = df["num_params"].apply(format_num) + if "embed_params" in df.columns: + df["embed_params"] = df["embed_params"].apply(format_num) + if "train_tokens" in df.columns: + df["train_tokens"] = df["train_tokens"].apply(format_num) + + pd.set_option("display.max_rows", None) + pd.set_option("display.max_columns", None) + pd.set_option("display.width", 1000) + + logger.info("\n" + "=" * 80) + logger.info("Configuration Summary Dataframe") + logger.info("=" * 80) + logger.info( + "\n" + + str(df.drop(columns=["model_config", "intermediate_dim", "n_kv_heads", "dataset_tokens", "steps_per_run"])) + ) + + logger.info("\n" + "=" * 50) + logger.info("Configs per Budget") + logger.info("=" * 50) + logger.info("\n" + str(df.groupby("budget").size())) + + logger.info("\n" + "=" * 50) + logger.info("Configs per Architecture") + logger.info("=" * 50) + logger.info("\n" + str(df.groupby("architecture").size())) + + # Create table of unique param count mappings + logger.info("\n" + "=" * 50) + logger.info("Param Count Mapping") + logger.info("=" * 50) + raw_df = pd.DataFrame([dataclasses.asdict(c) for c in all_configs]) + param_mapping = raw_df[["num_params", "embed_params"]].drop_duplicates() + param_mapping["nonembedding_params"] = param_mapping["num_params"] - param_mapping["embed_params"] + param_mapping["pct_embedding_params"] = ( + param_mapping["embed_params"] / param_mapping["num_params"] * 100 + ).round(2) + param_mapping = param_mapping[ + ["num_params", "embed_params", "nonembedding_params", "pct_embedding_params"] + ].sort_values("num_params") + param_mapping = param_mapping.rename( + columns={ + "num_params": "Total Params", + "embed_params": "Embedding Params", + "nonembedding_params": "Non-Embedding Params", + "pct_embedding_params": "% Embedding", + } + ) + logger.info("\n" + param_mapping.to_string(index=False)) + + logger.info("=" * 50 + "\n") + else: + logger.warning("No configurations generated!") + + +def generate_isoflop_steps(config: IsoFlopSweepConfig) -> list[ExecutorStep]: + """Generate executor steps for an ISOFlop sweep.""" + + # Collect all run configs first, downsampling per (budget, architecture) combination + all_configs: list[IsoFlopRunConfig] = [] + + for budget in config.budgets: + # Collect configs for this budget, grouped by architecture + budget_configs = list(generate_run_configs(config, budget)) + + # Group by architecture + configs_by_arch: dict[str, list[IsoFlopRunConfig]] = {} + for c in budget_configs: + if c.architecture not in configs_by_arch: + configs_by_arch[c.architecture] = [] + configs_by_arch[c.architecture].append(c) + + # Downsample each architecture group and add to all_configs + for arch, arch_configs in configs_by_arch.items(): + downsampled = downsample_configs(arch_configs) + if len(arch_configs) > len(downsampled): + logger.info( + f"Downsampled {len(arch_configs)} -> {len(downsampled)} configs " + f"for budget={budget:.1e}, architecture={arch}" + ) + all_configs.extend(downsampled) + + _log_isoflop_run_configs(all_configs) + + # Generate executor steps from validated configs + steps: list[ExecutorStep] = [] + for c in all_configs: + # Use the pre-computed model config + model_cfg = c.model_config + + tpu_type = pick_v5p_type( + config=model_cfg, + hidden=c.hidden_size, + layers=c.num_layers, + batch=c.batch_size, + seq_len=config.seq_len, + vocab=config.vocab_size, + ) + optimizer_cfg = replace(config.base_optimizer_config, learning_rate=c.lr, beta2=c.beta2) + train_cfg = replace( + config.base_train_config, + train_batch_size=c.batch_size, + learning_rate=c.lr, + num_train_steps=c.train_steps, + steps_per_eval=c.steps_per_eval, + per_device_eval_parallelism=config.per_device_eval_parallelism, + max_eval_batches=config.max_eval_batches, + resources=ResourceConfig.with_tpu(tpu_type), + optimizer_config=optimizer_cfg, + ) + + param_count = c.num_params + step = simulated_epoch_train( + name="-".join( + [ + config.experiment_name, + f"A_{c.architecture}", + f"F{c.budget:.1e}", + f"P{format_num(param_count)}", + f"T{format_num(c.train_tokens)}", + f"E{c.epoch_count}", + ] + ), + tokenized=config.tokenized_dataset, + model_config=model_cfg, + train_config=train_cfg, + train_tokens=c.train_tokens, + dataset_tokens=c.dataset_tokens, + epoch_count=c.epoch_count, + eval_harness_tasks=[], + tags=( + f"architecture={c.architecture}", + f"flops_budget={c.budget:.1e}", + f"hidden_size={c.hidden_size}", + f"num_layers={c.num_layers}", + f"batch_size={c.batch_size}", + f"steps={c.train_steps}", + f"tpu={tpu_type}", + f"params={param_count}", + f"params_embed={c.embed_params}", + f"params_nonembed={c.num_params - c.embed_params}", + f"tokens={c.train_tokens}", + f"epochs={c.epoch_count}", + ), + ) + steps.append(step) + + return steps + + +def generate_isoflop_sweeps( + tokenized: LMMixtureDatasetConfig, + vocab_size: int, + seq_len: int, + total_token_count: int, +) -> list[ExecutorStep]: + """Generate executor steps for all ISOFlop sweeps.""" + all_steps: list[ExecutorStep] = [] + for sweep_params in ISOFLOP_SWEEPS.values(): + sweep_cfg = IsoFlopSweepConfig( + tokenized_dataset=tokenized, + vocab_size=vocab_size, + seq_len=seq_len, + total_token_count=total_token_count, + experiment_name=sweep_params.experiment_name, + compute_range_name=sweep_params.compute_range_name, + budgets=sweep_params.budgets, + steps_per_run=sweep_params.steps_per_run, + hidden_step_size=sweep_params.hidden_step_size, + hidden_head_ratio=sweep_params.hidden_head_ratio, + ) + steps = generate_isoflop_steps(sweep_cfg) + all_steps.extend(steps) + return all_steps + + +def get_data_config() -> tuple[LMMixtureDatasetConfig, IsoFlopDataConfig]: + """Use pretokenized DCLM baseline dataset from GCS. + + Returns: + A tuple of (LMMixtureDatasetConfig, IsoFlopDataConfig with dataset metadata). + """ + data_config = IsoFlopDataConfig() + # Create LMMixtureDatasetConfig directly with absolute GCS path (no download needed) + mixture_config = LMMixtureDatasetConfig( + configs={ + "dclm_baseline": UrlDatasetSourceConfig( + cache_dir=data_config.tokenized_path, + ) + }, + train_weights={"dclm_baseline": 1.0}, + tokenizer=data_config.tokenizer, + # TODO: reduce this after ruling out eval noise (back to ~1024) + num_validation_sequences={"dclm_baseline": 131_072}, + ) + return mixture_config, data_config + + +def main(): + mixture_config, data_config = get_data_config() + + # Generate sweep steps + plantcad_sweep = generate_isoflop_sweeps( + tokenized=mixture_config, + vocab_size=data_config.vocab_size, + seq_len=data_config.seq_len, + total_token_count=data_config.total_token_count, + ) + + # Execute in batches of 16 sweep steps to avoid head node OOM errors. + # See: https://discord.com/channels/1354881461060243556/1442689455554171071/1447914001957785620 + # TODO: Figure out how to run on compute nodes instead w/o reverting back to this PR: + # https://discord.com/channels/1354881461060243556/1442689455554171071/1447920947402375291 + batch_size = 16 + batches = [plantcad_sweep[i : i + batch_size] for i in range(0, len(plantcad_sweep), batch_size)] + batch_index: int | None = int(os.environ["SWEEP_BATCH_INDEX"]) if "SWEEP_BATCH_INDEX" in os.environ else None + if batch_index is not None: + logger.info(f"SWEEP_BATCH_INDEX={batch_index}; running batch {batch_index + 1} of {len(batches)}") + batches = [batches[int(batch_index)]] + for i, batch in enumerate(batches): + logger.info(f"Running batch {i + 1}/{len(batches)} with {len(batch)} sweep steps") + if os.environ.get("DRY_RUN"): + logger.info("DRY RUN; skipping execution") + continue + executor_main(steps=batch) + + +if __name__ == "__main__": + main() diff --git a/lib/levanter/src/levanter/data/text.py b/lib/levanter/src/levanter/data/text.py index b0d5c59f07..1df43a0031 100644 --- a/lib/levanter/src/levanter/data/text.py +++ b/lib/levanter/src/levanter/data/text.py @@ -1022,6 +1022,12 @@ def shuffle_ds(ds, key): # Note(Will): This blocks on datasets being fully processed even for small simulated runs making simulating data size slightly latency inducing but I think that's ok true_length_of_dataset = len(ds.as_sync_dataset()) simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio) + logger.info( + f"Slicing simulated dataset {name} from {true_length_of_dataset} examples to {simulated_length_of_dataset} examples; " + f"experiment_budget: {self.experiment_budget}, " + f"target_budget: {self.target_budget}, " + f"simulated_data_ratio: {simulated_data_ratio}" + ) sliced_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset) datasets = sliced_datasets diff --git a/lib/levanter/src/levanter/main/train_lm.py b/lib/levanter/src/levanter/main/train_lm.py index 68f3228e7b..27aa117409 100644 --- a/lib/levanter/src/levanter/main/train_lm.py +++ b/lib/levanter/src/levanter/main/train_lm.py @@ -199,6 +199,11 @@ def loss_function(model: LmHeadModel, example: LmExample, *, key=None): if len(tagged_eval_datasets) == 0: logger.warning("No evaluation datasets provided.") else: + # TODO: remove + logger.info(f"[run={config.trainer.id}] train_config: {config.trainer}") + logger.info(f"[run={config.trainer.id}] eval_batch_size: {config.trainer.eval_batch_size}") + # TODO: /remove + cb = levanter.eval.cb_tagged_lm_evaluate( EvalBatch, tagged_eval_datasets, diff --git a/lib/marin/src/marin/run/ray_run.py b/lib/marin/src/marin/run/ray_run.py index 825c891411..580a09c184 100644 --- a/lib/marin/src/marin/run/ray_run.py +++ b/lib/marin/src/marin/run/ray_run.py @@ -21,11 +21,12 @@ import re import shlex import subprocess +import sys import time from pathlib import Path import yaml -from ray.job_submission import JobSubmissionClient +from ray.job_submission import JobStatus, JobSubmissionClient from marin.cluster.config import find_config_by_region from fray.cluster.ray import DashboardConfig, ray_dashboard @@ -120,8 +121,12 @@ async def submit_and_track_job( entrypoint_num_gpus: float | None = None, entrypoint_memory: int | None = None, entrypoint_resources: dict | None = None, -): - """Submit a job to Ray and optionally track logs.""" +) -> JobStatus | None: + """Submit a job to Ray and optionally track logs. + + Returns: + The final JobStatus if waiting for the job, None if no_wait is True. + """ client = make_client() current_dir = os.getcwd() @@ -170,12 +175,17 @@ async def submit_and_track_job( logger.info(f"Job URL: {client.get_address()}/#/jobs/{submission_id}") if no_wait: - return + return None # Stream logs asynchronously async for lines in client.tail_job_logs(submission_id): print(lines, end="") + # Check final job status after log streaming completes + final_status = client.get_job_status(submission_id) + logger.info(f"Job {submission_id} finished with status: {final_status}") + return final_status + def main(): """Parse command-line arguments and submit the job.""" @@ -322,9 +332,9 @@ def main(): else: submission_id = generate_submission_id(full_cmd) - async def run_job(): + async def run_job() -> JobStatus | None: try: - await submit_and_track_job( + return await submit_and_track_job( full_cmd, args.extra, env_vars, @@ -336,22 +346,24 @@ async def run_job(): entrypoint_resources=entrypoint_resources, ) except KeyboardInterrupt: - pass + return None except asyncio.CancelledError: logger.info("Job tracking cancelled by user.") - pass + return None except Exception as e: logger.error(f"Error submitting or tracking job: {e}") raise + final_status: JobStatus | None = None try: if cluster_config: with ray_dashboard(DashboardConfig.from_cluster(cluster_config)): - asyncio.run(run_job()) + final_status = asyncio.run(run_job()) else: - asyncio.run(run_job()) + final_status = asyncio.run(run_job()) except Exception: logger.error("Failed to run job", exc_info=True) + sys.exit(1) finally: if args.auto_stop: logger.info(f"Auto-stopping job {submission_id}...") @@ -364,6 +376,23 @@ async def run_job(): client = make_client() client.stop_job(submission_id) + # Exit with appropriate code based on job status + if final_status is None: + # no_wait mode or interrupted - exit successfully + sys.exit(0) + elif final_status == JobStatus.SUCCEEDED: + sys.exit(0) + elif final_status == JobStatus.FAILED: + logger.error(f"Job {submission_id} failed") + sys.exit(1) + elif final_status == JobStatus.STOPPED: + logger.warning(f"Job {submission_id} was stopped") + sys.exit(2) + else: + # Unexpected status (PENDING, RUNNING shouldn't happen after log streaming) + logger.warning(f"Job {submission_id} ended with unexpected status: {final_status}") + sys.exit(3) + if __name__ == "__main__": logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") diff --git a/scripts/claude_v_ray.sh b/scripts/claude_v_ray.sh new file mode 100755 index 0000000000..05fcc935e5 --- /dev/null +++ b/scripts/claude_v_ray.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash +# +# Wrapper script to run ray_run.py with automatic retry on failure. +# +# Usage: +# ./scripts/ray_run_retry.sh [OPTIONS] -- +# +# Options: +# --max-retries N Maximum number of retry attempts (default: 3) +# +# Example: +# ./scripts/ray_run_retry.sh --max-retries 5 -- \ +# --cluster us-central2 -- uv run python experiments/my_exp.py +# +# Exit codes from ray_run.py: +# 0 = SUCCEEDED +# 1 = FAILED +# 2 = STOPPED +# 3 = Unexpected status + +set -euo pipefail + +# Default configuration +MAX_RETRIES=3 + +# Parse wrapper options +while [[ $# -gt 0 ]]; do + case "$1" in + --max-retries) + MAX_RETRIES="$2" + shift 2 + ;; + --) + shift + break + ;; + *) + echo "Unknown option: $1" >&2 + echo "Usage: $0 [--max-retries N] -- " >&2 + exit 1 + ;; + esac +done + +if [[ $# -eq 0 ]]; then + echo "Error: No ray_run.py arguments provided after '--'" >&2 + echo "Usage: $0 [--max-retries N] -- " >&2 + exit 1 +fi + +# Store ray_run arguments +RAY_RUN_ARGS=("$@") + +attempt=0 + +while true; do + attempt=$((attempt + 1)) + echo "==============================================" + echo "Attempt $attempt of $((MAX_RETRIES + 1))" + echo "==============================================" + + set +e + uv run python -m marin.run.ray_run "${RAY_RUN_ARGS[@]}" + exit_code=$? + set -e + + case $exit_code in + 0) + echo "Job succeeded on attempt $attempt" + exit 0 + ;; + 1) + echo "Job failed (exit code 1)" + ;; + 2) + echo "Job was stopped (exit code 2)" + # Don't retry stopped jobs - they were intentionally stopped + echo "Not retrying stopped jobs." + exit 2 + ;; + *) + echo "Job exited with unexpected code: $exit_code" + ;; + esac + + if [[ $attempt -gt $MAX_RETRIES ]]; then + echo "Max retries ($MAX_RETRIES) exceeded. Giving up." + exit $exit_code + fi + + echo "Retrying immediately..." +done