diff --git a/experiments/exp2284_test_remez.py b/experiments/exp2284_test_remez.py
new file mode 100644
index 0000000000..56c31ba16e
--- /dev/null
+++ b/experiments/exp2284_test_remez.py
@@ -0,0 +1,100 @@
+# Copyright 2025 The Marin Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Learning rate sweep for a 300M-parameter Llama model trained with the MuonRemez optimizer."""
+
+import logging
+
+from experiments.defaults import default_train
+from experiments.pretraining_datasets.dclm import dclm_mixture_config_llama3
+from experiments.simple_train_config import SimpleTrainConfig
+from fray.cluster import ResourceConfig
+from levanter.callbacks.watch import WatchConfig
+from levanter.models.llama import LlamaConfig
+from levanter.optim import MuonRemezConfig
+from marin.execution.executor import ExecutorStep, executor_main
+from marin.training.training import TrainLmOnPodConfig
+
+logger = logging.getLogger("ray")
+
+llama_300m = LlamaConfig(
+    max_seq_len=4096,
+    hidden_dim=768,
+    intermediate_dim=3072,
+    num_heads=12,
+    num_kv_heads=12,  # same head ratio as the reference config
+    num_layers=32,
+)
+
+
+def _format_lr(lr: float) -> str:
+    return f"{lr:.2e}".replace("e-0", "e-").replace("e+0", "e+").replace("+", "")
+
+
+def _lr_sweep() -> list[ExecutorStep[TrainLmOnPodConfig]]:
+    """Construct one training step per learning rate in the sweep."""
+
+    steps = []
+    num_train_steps = 11520
+    weight_decay = 0.1  # fixed weight decay; only the learning rate is swept
+    lrs = [0.001, 0.002, 0.004, 0.008, 0.012]
+    for lr in lrs:
+        optimizer_config = MuonRemezConfig(
+            learning_rate=lr,
+            warmup=0,
+            min_lr_ratio=0.0,
+            lr_schedule="cosine",
+            adam_lr=0.4 * lr,
+            momentum=0.95,
+            nesterov=True,
+            backend_steps=7,
+            weight_decay=weight_decay,
+            beta1=0.9,
+            beta2=0.95,
+            epsilon=1e-8,
+            muon_epsilon=1e-8,
+            max_grad_norm=1.0,
+            use_kimi_scaling=False,
+        )
+        train_config = SimpleTrainConfig(
+            resources=ResourceConfig.with_tpu("v5litepod-64"),
+            train_batch_size=128,
+            num_train_steps=num_train_steps,
+            learning_rate=lr,
+            watch=WatchConfig(watch_targets=["grads", "params"], interval=10),
+            optimizer_config=optimizer_config,
+        )
+        step = default_train(
+            name=f"test-remez-300m-lr{_format_lr(lr)}-wd{weight_decay}-warmup0-alr0.4",
+            tokenized=dclm_mixture_config_llama3,
+            model_config=llama_300m,
+            train_config=train_config,
+            tags=("remez", "300m", "lr_sweep", "cosine"),
+            eval_harness_tasks=(),
+        )
+
+        steps.append(step)
+
+    return steps
+
+
+def build_steps():
+    all_steps = []
+    all_steps.extend(_lr_sweep())
+
+    return all_steps
+
+
+if __name__ == "__main__":
+    executor_main(steps=build_steps())
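Editor's sketch (not part of the patch): a minimal check of what `_format_lr` produces for the swept learning rates, and hence the run-name suffixes such as "lr1.00e-3".

# Assumes _format_lr exactly as defined in the file above.
def _format_lr(lr: float) -> str:
    return f"{lr:.2e}".replace("e-0", "e-").replace("e+0", "e+").replace("+", "")

for lr in [0.001, 0.002, 0.004, 0.008, 0.012]:
    print(_format_lr(lr))  # -> 1.00e-3, 2.00e-3, 4.00e-3, 8.00e-3, 1.20e-2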
diff --git a/lib/levanter/src/levanter/optim/__init__.py b/lib/levanter/src/levanter/optim/__init__.py
index a45fcce29e..d645a3358e 100644
--- a/lib/levanter/src/levanter/optim/__init__.py
+++ b/lib/levanter/src/levanter/optim/__init__.py
@@ -23,6 +23,7 @@
     # muon
     "MuonConfig",
     "MuonHConfig",
+    "MuonRemezConfig",
     "ScaleByMuonState",
     # rmsprop
     "RMSPropMomentumConfig",
@@ -52,6 +53,7 @@
 from .mars import MarsConfig, ScaleByMarsState
 from .muon import MuonConfig, ScaleByMuonState
 from .muonh import MuonHConfig
+from .muonremez import MuonRemezConfig
 from .rmsprop import RMSPropMomentumConfig, ScaleByRMSPropMomState
 from .scion import ScaleByScionState, ScionConfig
 from .soap import SoapConfig
diff --git a/lib/levanter/src/levanter/optim/muonremez.py b/lib/levanter/src/levanter/optim/muonremez.py
new file mode 100644
index 0000000000..32efab0d63
--- /dev/null
+++ b/lib/levanter/src/levanter/optim/muonremez.py
@@ -0,0 +1,268 @@
+# Copyright 2025 The Levanter Authors
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from dataclasses import dataclass
+from functools import partial
+from typing import NamedTuple, Optional
+
+import chex
+import jax
+import jax.numpy as jnp
+import optax
+from optax import tree_utils as otu
+
+import haliax
+from haliax.nn import Linear
+
+from levanter.optim.config import OptimizerConfig
+from levanter.optim.util import map_flattened_linear_layers
+from levanter.utils.jax_utils import leaf_key_paths
+
+
+@OptimizerConfig.register_subclass("muonremez")
+@dataclass(frozen=True)
+class MuonRemezConfig(OptimizerConfig):
+    """
+    MuonRemez optimizer configuration: like Muon, but uses a coupled Newton-Schulz
+    iteration to compute U·Σ^(1/2)·V^T instead of Newton-Schulz orthogonalization
+    (which computes U·Σ^0·V^T = U·V^T).
+
+    Update rule:
+        v_t = β v_{t-1} + g_t                  (momentum)
+        u_t = compute_u_sigma_half_vt(v_t)     (matrix square root via coupled NS)
+        W  ← (1 - η·λ) W - η u_t               (update with decoupled weight decay)
+    """
+
+    lr: float = 0.02
+    adam_lr: float = 6e-4  # Adam LR
+    momentum: float = 0.95
+    nesterov: bool = True
+    backend_steps: int = 7  # number of coupled Newton-Schulz iterations
+    weight_decay: float = 0.0
+    adam_weight_decay: Optional[float] = None
+    beta1: float = 0.9
+    beta2: float = 0.95
+    epsilon: float = 1e-8
+    muon_epsilon: float = 1e-8
+    max_grad_norm: float = 1.0
+    use_kimi_scaling: bool = False
+
+    def build(self, num_train_steps):
+        """
+        Creates the optimizer.
+        """
+        learning_rate_schedule = self.lr_scheduler(num_train_steps)
+        adam_lr_schedule = self.lr_scheduler(num_train_steps, override_lr=self.adam_lr)
+
+        def optimizer(learning_rate, adam_lr):
+            def muonremez_transform():
+                components = []
+                components.append(
+                    scale_with_muonremez(
+                        self.momentum, self.nesterov, self.backend_steps, self.muon_epsilon, self.use_kimi_scaling
+                    )
+                )
+                if self.weight_decay > 0:
+                    components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask()))
+                components.append(optax.scale(-learning_rate))
+                optimizer = optax.chain(*components)
+                return optimizer
+
+            def adamw_transform():
+                components = []
+                if self.max_grad_norm:
+                    components.append(optax.clip_by_global_norm(self.max_grad_norm))
+                components.append(optax.scale_by_adam(self.beta1, self.beta2, self.epsilon))
+                adam_weight_decay = self.adam_weight_decay if self.adam_weight_decay is not None else self.weight_decay
+                if adam_weight_decay > 0:
+                    components.append(optax.add_decayed_weights(adam_weight_decay, self.build_weight_decay_mask()))
+                components.append(optax.scale(-adam_lr))
+                optimizer = optax.chain(*components)
+                return optimizer
+
+            transformations = {
+                "muonremez": muonremez_transform(),
+                "adamw": adamw_transform(),
+            }
+
+            return optax.multi_transform(
+                transformations, partial(self.create_mask, use_kimi_scaling=self.use_kimi_scaling)
+            )
+
+        return optax.inject_hyperparams(optimizer)(learning_rate=learning_rate_schedule, adam_lr=adam_lr_schedule)
+
+    def create_mask(self, params, use_kimi_scaling=True):
+        """
+        Labels parameters as 'muonremez' or 'adamw' based on their dimensionality and
+        module path: Embedding and lm_head parameters use AdamW, 2D linear weights use
+        MuonRemez, and everything else (biases, norms, etc.) uses AdamW.
+        """
+        paths = leaf_key_paths(params)
+
+        def mask_fn(param, path):
+            path_str = ".".join(path) if isinstance(path, (list, tuple)) else str(path)
+            if "Embedding" in path_str or "lm_head" in path_str:
+                return "adamw"
+            elif isinstance(param, Linear):
+                # MuonRemez for linear layers. Without Kimi scaling we rely on an
+                # out-first weight layout so the sqrt(Out/In) scale is computed from
+                # the right axes.
+                assert param._out_first or use_kimi_scaling
+                return dataclasses.replace(param, weight="muonremez", bias="adamw" if param.bias is not None else None)
+            else:
+                return "adamw"
+
+        return haliax.tree_util.tree_map(mask_fn, params, paths, is_leaf=lambda x: isinstance(x, Linear))
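# --- Editor's sketch (not part of the patch): the update rule from the class
# docstring written out for a single 2D weight, assuming compute_u_sigma_half_vt
# (defined later in this file) is in scope. The names w, g, v, lr, wd are
# illustrative only.
import jax
import jax.numpy as jnp

w = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
g = jax.random.normal(jax.random.PRNGKey(1), (8, 4))
v = jnp.zeros_like(w)
lr, wd, momentum = 1e-3, 0.1, 0.95

v = momentum * v + g                     # v_t = β v_{t-1} + g_t
u = compute_u_sigma_half_vt(v, steps=7)  # u_t = U·Σ^(1/2)·V^T of v_t
w = (1.0 - lr * wd) * w - lr * u         # decoupled weight decay + step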
+
+
+class ScaleByMuonRemezState(NamedTuple):
+    """State for the MuonRemez algorithm."""
+
+    momentum_buffer: optax.Updates
+
+
+def scale_with_muonremez(momentum=0.95, nesterov=True, steps=7, muon_eps=1e-8, use_kimi_scaling=False):
+    # Convert steps to a concrete int at function-definition time so it can drive
+    # the Python-level loop inside the jitted update.
+    steps = int(steps)
+
+    def init_fn(params):
+        momentum_buffer = otu.tree_zeros_like(params)
+        return ScaleByMuonRemezState(momentum_buffer=momentum_buffer)
+
+    def update_fn(updates, state, params=None):
+        buf = state.momentum_buffer
+        buf = jax.tree.map(
+            lambda m, g: None if g is None else momentum * m + g,
+            buf,
+            updates,
+            is_leaf=lambda x: x is None,
+        )
+        if nesterov:
+            updates = jax.tree.map(
+                lambda m, g: None if g is None else momentum * m + g,
+                buf,
+                updates,
+                is_leaf=lambda x: x is None,
+            )
+        else:
+            updates = buf
+
+        def transform_linear_layer(layer: haliax.nn.Linear):
+            assert layer.weight.ndim == 2
+            array = layer.weight.array
+            updated_weight_array = compute_u_sigma_half_vt(array, steps=steps, eps=muon_eps)
+
+            if not use_kimi_scaling:
+                # sqrt(Out/In), floored at 1
+                scale = jnp.sqrt(jnp.maximum(1, updated_weight_array.shape[0] / updated_weight_array.shape[1]))
+            else:
+                scale = 0.2 * jnp.sqrt(jnp.maximum(updated_weight_array.shape[0], updated_weight_array.shape[1]))
+            updated_weight_array *= scale
+
+            updated_weight = dataclasses.replace(layer.weight, array=updated_weight_array)
+
+            return dataclasses.replace(layer, weight=updated_weight)  # type: ignore
+
+        updates = map_flattened_linear_layers(transform_linear_layer, updates)
+
+        return updates, ScaleByMuonRemezState(momentum_buffer=buf)
+
+    return optax.GradientTransformation(init_fn, update_fn)
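# --- Editor's sketch (not part of the patch): the two scale rules above, side by
# side, for an illustrative 3072x768 (Out x In) weight.
import jax.numpy as jnp

out_dim, in_dim = 3072, 768
default_scale = jnp.sqrt(jnp.maximum(1, out_dim / in_dim))  # sqrt(Out/In) = 2.0
kimi_scale = 0.2 * jnp.sqrt(jnp.maximum(out_dim, in_dim))   # 0.2 * sqrt(3072) ≈ 11.1
print(default_scale, kimi_scale)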
+
+
+def get_quintic_coeffs(variant="higham"):
+    """
+    Returns quintic polynomial coefficients (alpha, beta, gamma).
+
+    Higham (standard): order-5 convergence, very stable; alpha + beta + gamma = 1,
+    so P = I is an exact fixed point of the coupled iteration.
+    Muon aggressive: optimized for speed (steep slope at 0), so small singular
+    values are pushed toward 1 faster.
+    """
+    if variant == "higham":
+        return 1.875, -1.25, 0.375
+    elif variant == "muon_aggressive":
+        # Muon's orthogonalization coefficients. Since 3.4445 - 4.7750 + 2.0315 ≈ 0.70
+        # (not 1), the coupled square-root iteration oscillates around the fixed point
+        # rather than converging exactly with them; 'higham' is the safe default here.
+        return 3.4445, -4.7750, 2.0315
+    else:
+        return 1.875, -1.25, 0.375
+
+
+def coupled_newton_schulz_quintic(A, steps=5, variant="muon_aggressive", eps=1e-8):
+    """
+    Coupled Newton-Schulz iteration for the matrix square root.
+    Targets: Y -> A^(1/2), Z -> A^(-1/2).
+
+    Args:
+        A: Input matrix (2D array; should be square and positive semidefinite)
+        steps: Number of Newton-Schulz iterations
+        variant: Coefficient variant ('higham' or 'muon_aggressive')
+        eps: Small constant for numerical stability
+
+    Returns:
+        Y: A^(1/2) (matrix square root)
+        Z: A^(-1/2) (inverse matrix square root)
+    """
+    chex.assert_rank(A, 2)
+    # A is square by construction in compute_u_sigma_half_vt.
+
+    alpha, beta, gamma = get_quintic_coeffs(variant)
+
+    # Normalization is crucial: dividing a PSD input by its Frobenius norm puts its
+    # spectrum in (0, 1], inside the iteration's basin of convergence.
+    norm_A = jnp.linalg.norm(A) + eps
+    A_scaled = A / norm_A
+
+    Y = A_scaled
+    Z = jnp.eye(A.shape[0], dtype=A.dtype)
+    I = jnp.eye(A.shape[0], dtype=A.dtype)
+
+    for _ in range(steps):
+        P = Z @ Y
+        P2 = P @ P
+        S = alpha * I + beta * P + gamma * P2
+        Y = Y @ S
+        Z = S @ Z
+
+    # Undo the normalization: sqrt(A) = sqrt(norm_A) * sqrt(A_scaled).
+    sqrt_norm = jnp.sqrt(norm_A)
+    return Y * sqrt_norm, Z * (1.0 / sqrt_norm)
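# --- Editor's sketch (not part of the patch): a quick numerical check of the
# coupled iteration above on a small SPD matrix, assuming the function is in
# scope. With enough steps, Y @ Y should recover A and Z @ Y should be near I.
import jax
import jax.numpy as jnp

B = jax.random.normal(jax.random.PRNGKey(0), (6, 6))
A = B @ B.T + 0.1 * jnp.eye(6)  # symmetric positive definite

Y, Z = coupled_newton_schulz_quintic(A, steps=20, variant="higham")
print(jnp.max(jnp.abs(Y @ Y - A)))           # small residual
print(jnp.max(jnp.abs(Z @ Y - jnp.eye(6))))  # small residual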
+
+
+def compute_u_sigma_half_vt(M, steps=7, eps=1e-8, variant="higham"):
+    """
+    Chained coupled iteration to compute U·Σ^(1/2)·V^T (the "matrix square root" of
+    the update).
+
+    This differs from the zeroth power (orthogonalization) used in standard Muon:
+    for a matrix G with SVD G = U·Σ·V^T, it returns an approximation to U·Σ^(1/2)·V^T.
+
+    Args:
+        M: Input matrix (2D array)
+        steps: Number of Newton-Schulz iterations
+        eps: Small constant for numerical stability
+        variant: Coefficient variant ('higham' or 'muon_aggressive')
+
+    Returns:
+        Matrix square root U·Σ^(1/2)·V^T
+    """
+    chex.assert_rank(M, 2)
+
+    # Pre-normalize M. Adding eps keeps the division well-defined (and JIT-safe)
+    # for an all-zero matrix.
+    norm_M = jnp.linalg.norm(M) + eps
+    M_scaled = M / norm_M
+
+    if M_scaled.shape[0] >= M_scaled.shape[1]:
+        # Tall matrix: work with the smaller Gram matrix M^T M = V·Σ²·V^T.
+        # Z1 ≈ (M^T M)^(-1/2) = V·Σ^(-1)·V^T and Y2 ≈ Z1^(1/2) = V·Σ^(-1/2)·V^T,
+        # so M @ Y2 = U·Σ^(1/2)·V^T.
+        A = M_scaled.T @ M_scaled
+        _, Z1 = coupled_newton_schulz_quintic(A, steps, variant, eps)
+        Y2, _ = coupled_newton_schulz_quintic(Z1, steps, variant, eps)
+        result_scaled = M_scaled @ Y2
+    else:
+        # Wide matrix: same construction with the Gram matrix M M^T = U·Σ²·U^T.
+        A = M_scaled @ M_scaled.T
+        _, Z1 = coupled_newton_schulz_quintic(A, steps, variant, eps)
+        Y2, _ = coupled_newton_schulz_quintic(Z1, steps, variant, eps)
+        result_scaled = Y2 @ M_scaled
+
+    # Undo the pre-normalization: Σ^(1/2) scales with sqrt(norm_M).
+    result = result_scaled * jnp.sqrt(norm_M)
+    return result
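Editor's sketch (not part of the patch): checking compute_u_sigma_half_vt against an explicit SVD reference, assuming the function above is in scope. For a well-conditioned G and enough steps the two should agree closely.

import jax
import jax.numpy as jnp

G = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
U, S, Vt = jnp.linalg.svd(G, full_matrices=False)
reference = U @ jnp.diag(jnp.sqrt(S)) @ Vt  # U·Σ^(1/2)·V^T

approx = compute_u_sigma_half_vt(G, steps=20)
print(jnp.max(jnp.abs(approx - reference)))  # small for enough steps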