diff --git a/configs/10B/H100.toml b/configs/10B/H100.toml
index d58e6098..d743cc8a 100644
--- a/configs/10B/H100.toml
+++ b/configs/10B/H100.toml
@@ -11,14 +11,16 @@ sched_type = "wsd-sqrt"
 batch_size = 128 #1M tokens bs
 warmup_steps = 1000
 total_steps = 1_000_000_000_000
-lr = 7.5e-5
-adam_betas1 = 0.9
-adam_betas2 = 0.95
-weight_decay = 0.1
 z_loss = true
 
+[optim.optim]
+lr = 7.5e-5
+betas1 = 0.9
+betas2 = 0.95
+weight_decay = 0.1
+
 [data]
 seq_length = 8192
 dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data/datasets/StackV1-popular,/data/datasets/dclm-baseline-1.0-parquet,/data/datasets/open-web-math"
diff --git a/configs/10B/H100_cooldown.toml b/configs/10B/H100_cooldown.toml
index 9132b1e8..c443e0ed 100644
--- a/configs/10B/H100_cooldown.toml
+++ b/configs/10B/H100_cooldown.toml
@@ -12,14 +12,15 @@ batch_size = 128 #1M tokens bs
 warmup_steps = 1000
 stable_steps = 74700
 total_steps = 90400
-lr = 7.5e-5
-
-adam_betas1 = 0.9
-adam_betas2 = 0.95
-weight_decay = 0.1
 z_loss = true
 
+[optim.optim]
+lr = 7.5e-5
+betas1 = 0.9
+betas2 = 0.95
+weight_decay = 0.1
+
 [data]
 seq_length = 8192
 dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data/datasets/StackV1-popular"
diff --git a/configs/13B/H100.toml b/configs/13B/H100.toml
index 176a9a12..4bfc3e05 100644
--- a/configs/13B/H100.toml
+++ b/configs/13B/H100.toml
@@ -9,6 +9,8 @@ ac_ckpt = true
 batch_size = 1024 #2M tokens bs
 warmup_steps = 1000
 total_steps = 88_000
+
+[optim.optim]
 lr = 3e-4
 
 [data]
diff --git a/configs/150M/3090.toml b/configs/150M/3090.toml
index 761d1b66..a304abd8 100644
--- a/configs/150M/3090.toml
+++ b/configs/150M/3090.toml
@@ -10,4 +10,8 @@ reshard_after_forward = true
 batch_size = 512
 warmup_steps = 1000
 total_steps = 88_000
-lr = 4e-4
\ No newline at end of file
+
+
+[optim.optim]
+lr = 4e-4
+
diff --git a/configs/150M/A40.toml b/configs/150M/A40.toml
index c82f2df4..ddbef1a5 100644
--- a/configs/150M/A40.toml
+++ b/configs/150M/A40.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
 batch_size = 512
 warmup_steps = 1000
 total_steps = 88_000
-lr = 4e-4
\ No newline at end of file
+
+[optim.optim]
+lr = 4e-4
+
diff --git a/configs/150M/H100.toml b/configs/150M/H100.toml
index b15c1750..a6339181 100644
--- a/configs/150M/H100.toml
+++ b/configs/150M/H100.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
 batch_size = 512
 warmup_steps = 1000
 total_steps = 88_000
-lr = 4e-4
\ No newline at end of file
+
+[optim.optim]
+lr = 4e-4
+
diff --git a/configs/150M_short/3090.toml b/configs/150M_short/3090.toml
index 4792bc1b..a468b64c 100644
--- a/configs/150M_short/3090.toml
+++ b/configs/150M_short/3090.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
 batch_size = 512
 warmup_steps = 500
 total_steps = 8192
-lr = 4e-4
\ No newline at end of file
+
+
+[optim.optim]
+lr = 4e-4
diff --git a/configs/150M_short/A40.toml b/configs/150M_short/A40.toml
index 17aa7aca..80756de5 100644
--- a/configs/150M_short/A40.toml
+++ b/configs/150M_short/A40.toml
@@ -6,8 +6,12 @@ type_model = "llama2"
 micro_bs = 32 # change this base on the gpu
 reshard_after_forward = true
 
+
 [optim]
 batch_size = 512
 warmup_steps = 500
 total_steps = 8192
-lr = 4e-4
\ No newline at end of file
+
+
+[optim.optim]
+lr = 4e-4
diff --git a/configs/150M_short/H100.toml b/configs/150M_short/H100.toml
index af7582e0..f7a7223d 100644
--- a/configs/150M_short/H100.toml
+++ b/configs/150M_short/H100.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
 batch_size = 512
 warmup_steps = 500
 total_steps = 8192
-lr = 4e-4
\ No newline at end of file
+
+
+[optim.optim]
+lr = 4e-4
diff --git a/configs/1B/H100.toml b/configs/1B/H100.toml
index 81cfe3f6..eee5c3ff 100644
--- a/configs/1B/H100.toml
+++ b/configs/1B/H100.toml
@@ -3,10 +3,13 @@ project = "debug_1B_zero_band"
 type_model = "llama2"
 
 [train]
-micro_bs = 16
+micro_bs = 32
+reshard_after_forward = true
 
 [optim]
-batch_size = 2048
+batch_size = 1024
 warmup_steps = 1000
-total_steps = 88_000
-lr = 4e-4
\ No newline at end of file
+total_steps = 8192
+
+[optim.optim]
+lr = 7e-4
\ No newline at end of file
diff --git a/configs/1B/H100_c4.toml b/configs/1B/H100_c4.toml
deleted file mode 100644
index 1695017d..00000000
--- a/configs/1B/H100_c4.toml
+++ /dev/null
@@ -1,15 +0,0 @@
-name_model = "1B"
-project = "debug_1B_zero_band"
-type_model = "llama2"
-
-[train]
-micro_bs = 16
-
-[optim]
-batch_size = 128
-warmup_steps = 1000
-total_steps = 88_000
-lr = 3e-4
-
-[data]
-seq_length = 2048
\ No newline at end of file
diff --git a/configs/1B/H100_llama2_edu.toml b/configs/1B/H100_llama2_edu.toml
deleted file mode 100644
index 31eb3a32..00000000
--- a/configs/1B/H100_llama2_edu.toml
+++ /dev/null
@@ -1,21 +0,0 @@
-name_model = "1B"
-project = "debug_1B_zero_band"
-type_model = "llama2"
-
-[train]
-micro_bs = 4
-reshard_after_forward = true
-
-[data]
-seq_length = 8192
-num_workers = 4
-dataset_name_or_paths = "/data/datasets/fineweb-edu"
-reverse_data_files = true
-
-[optim]
-batch_size = 256
-warmup_steps = 1000
-total_steps = 1_000_000_000_000
-sched_type = "wsd-sqrt"
-lr = 4e-4
-z_loss = true
diff --git a/configs/1B/H100_llama2_edu_no_feat.toml b/configs/1B/H100_llama2_edu_no_feat.toml
deleted file mode 100644
index 0afd432f..00000000
--- a/configs/1B/H100_llama2_edu_no_feat.toml
+++ /dev/null
@@ -1,23 +0,0 @@
-name_model = "1B"
-project = "debug_1B_zero_band"
-type_model = "llama2"
-
-[train]
-micro_bs = 4
-reshard_after_forward = true
-attn_fn = "sdpa"
-sequence_packing = false
-
-[data]
-seq_length = 8192
-num_workers = 4
-dataset_name_or_paths = "/data/datasets/fineweb-edu"
-reverse_data_files = true
-
-[optim]
-batch_size = 256
-warmup_steps = 1000
-total_steps = 1_000_000_000_000
-sched_type = "wsd-sqrt"
-lr = 2e-4
-z_loss = false
diff --git a/configs/1B/H100_llama3.toml b/configs/1B/H100_llama3.toml
deleted file mode 100644
index d4b3ee23..00000000
--- a/configs/1B/H100_llama3.toml
+++ /dev/null
@@ -1,22 +0,0 @@
-name_model = "1B"
-project = "debug_1B_zero_band"
-type_model = "llama3"
-
-[train]
-micro_bs = 1
-reshard_after_forward = true
-
-[data]
-seq_length = 8192
-num_workers = 4
-dataset_name_or_paths = "/data/datasets/fineweb-edu"
-reverse_data_files = true
-
-[optim]
-batch_size = 256
-warmup_steps = 1000
-total_steps = 1_000_000_000_000
-sched_type = "wsd-sqrt"
-lr = 4e-4
-z_loss = true
-
diff --git a/configs/1B_diloco/H100.toml b/configs/1B_diloco/H100.toml
deleted file mode 100644
index 19d3259f..00000000
--- a/configs/1B_diloco/H100.toml
+++ /dev/null
@@ -1,25 +0,0 @@
-name_model = "1B"
-project = "debug_1B_zero_band"
-type_model = "llama2"
-
-[train]
-micro_bs = 16
-
-[optim]
-batch_size = 2048
-warmup_steps = 1000
-total_steps = 88_000
-lr = 4e-4
-
-z_loss = true
-
-
-[diloco]
-inner_steps = 50
-compression = "uint8"
-
-[ckpt]
-interval = 50
-topk = 3
-path = "outputs_1b_diloco_50"
-
diff --git a/configs/7B/H100.toml b/configs/7B/H100.toml
index f701ef7c..7ea3dc65 100644
--- a/configs/7B/H100.toml
+++ b/configs/7B/H100.toml
@@ -9,6 +9,8 @@ micro_bs = 1
 batch_size = 1024 #2M tokens bs
 warmup_steps = 1000
 total_steps = 88_000
+
+[optim.optim]
 lr = 3e-4
 
 [data]
diff --git a/configs/7B_diloco/H100.toml b/configs/7B_diloco/H100.toml
index ceeccc43..b6a84d2c 100644
--- a/configs/7B_diloco/H100.toml
+++ b/configs/7B_diloco/H100.toml
@@ -9,6 +9,8 @@ micro_bs = 1
 batch_size = 1024 #2M tokens bs
 warmup_steps = 1000
 total_steps = 88_000
+
+[optim.optim]
 lr = 3e-4
 
 [data]
diff --git a/configs/test.toml b/configs/test.toml
index 46abc536..d9f9726d 100644
--- a/configs/test.toml
+++ b/configs/test.toml
@@ -15,4 +15,6 @@ num_workers = 1
 batch_size = 128
 warmup_steps = 1000
 total_steps = 88_000
+
+[optim.optim]
 lr = 4e-4
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index da393919..6020a3f4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
     "pyarrow",
     "toposolve",
     "psutil",
+    "torch-shampoo @ git+https://github.com/facebookresearch/optimizers.git@main",
 ]
 
 [project.optional-dependencies]
diff --git a/scripts/simulate_multi_node_diloco.sh b/scripts/simulate_multi_node_diloco.sh
index db2e2a17..0b13fe3b 100755
--- a/scripts/simulate_multi_node_diloco.sh
+++ b/scripts/simulate_multi_node_diloco.sh
@@ -61,7 +61,7 @@ export GLOO_SOCKET_IFNAME=lo
 for i in $(seq 0 $(($N - 1 )))
 do
     > logs/log$i.log
-    WANDB_MODE=$([ $i -eq 0 ] && echo "online" || echo "online") GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((BASE_PORT + $i)) --nnodes=1 $@ --data.data_rank $i --data.data_world_size $N > logs/log$i.log 2>&1 &
+    WANDB_MODE=$([ $i -eq 0 ] && echo "online" || echo "offline") GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((BASE_PORT + $i)) --nnodes=1 $@ --data.data_rank $i --data.data_world_size $N > logs/log$i.log 2>&1 &
     child_pids+=($!)
 done
diff --git a/src/zeroband/optimizers/__init__.py b/src/zeroband/optimizers/__init__.py
new file mode 100644
index 00000000..46f74d3f
--- /dev/null
+++ b/src/zeroband/optimizers/__init__.py
@@ -0,0 +1,59 @@
+from typing import TypeAlias
+from pydantic_config import BaseConfig
+import torch
+from zeroband.optimizers.muon import Muon, AdamConfig, MuonConfig
+from distributed_shampoo import EighEigenvalueCorrectionConfig, DistributedShampoo, FullyShardShampooConfig
+
+
+class SoapConfig(BaseConfig):
+    lr: float = 4e-4
+    weight_decay: float = 1e-05
+    betas1: float = 0.9
+    betas2: float = 0.95
+
+    max_preconditioner_dim: int = 8192
+    precondition_frequency: int = 100
+
+
+OptimizersConfig: TypeAlias = AdamConfig | MuonConfig | SoapConfig
+
+
+def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer:
+    if isinstance(config, AdamConfig):
+        return torch.optim.AdamW(
+            params,
+            lr=config.lr,
+            weight_decay=config.weight_decay,
+            betas=(config.betas1, config.betas2),
+        )
+    elif isinstance(config, MuonConfig):
+        return Muon(
+            params,
+            lr=config.lr,
+            momentum=config.momentum,
+            nesterov=config.nesterov,
+            ns_steps=config.pseudo_order_steps,
+            adamw_lr=config.adam.lr,
+            adamw_betas=(config.adam.betas1, config.adam.betas2),
+            adamw_wd=config.adam.weight_decay,
+        )
+    elif isinstance(config, SoapConfig):
+        return DistributedShampoo(
+            params,
+            lr=config.lr,
+            betas=(config.betas1, config.betas2),
+            epsilon=1e-12,
+            weight_decay=config.weight_decay,
+            max_preconditioner_dim=config.max_preconditioner_dim,
+            precondition_frequency=config.precondition_frequency,
+            use_decoupled_weight_decay=True,
+            # This can also be set to `QREigenvalueCorrectionConfig`, which is less expensive
+            # and might therefore allow for a smaller `precondition_frequency`.
+            preconditioner_computation_config=EighEigenvalueCorrectionConfig(),
+            distributed_config=FullyShardShampooConfig(),
+        )
+    else:
+        raise ValueError(f"Unknown optimizer config: {config}")
+
+
+__all__ = ["OptimizersConfig", "get_optimizer"]
diff --git a/src/zeroband/optimizers/muon.py b/src/zeroband/optimizers/muon.py
new file mode 100644
index 00000000..2590f618
--- /dev/null
+++ b/src/zeroband/optimizers/muon.py
@@ -0,0 +1,224 @@
+# credits to https://github.com/ethansmith2000/fsdp_optimizers/blob/main/muon.py
+
+from pydantic_config import BaseConfig
+import torch
+from typing import Generator
+from torch.distributed._tensor.api import (
+    DTensor,
+    distribute_tensor,
+)  # should be moved to torch.distributed.tensor with torch 2.5.0
+
+
+def to_local(x, keep_sharded=False):
+    if isinstance(x, DTensor):
+        meta = dict(
+            device_mesh=x.device_mesh,
+            placements=x.placements,
+            shape=x.shape,
+            stride=x.stride(),
+        )
+        if keep_sharded:
+            return x.to_local(), meta
+        else:
+            return x.full_tensor(), meta
+
+    return x, None
+
+
+def to_dist(x, **meta):
+    # return DTensor.from_local(x, **meta)
+    return distribute_tensor(x, device_mesh=meta["device_mesh"], placements=meta["placements"])
+
+
+# @torch.compile
+def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert len(G.shape) == 2
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    X /= X.norm() + eps  # ensure top singular value <= 1
+    if G.size(0) > G.size(1):
+        X = X.T
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Some warnings:
+    - We believe this optimizer is unlikely to work well for training with small batch size.
+    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
+
+    Arguments:
+        muon_params: The parameters to be optimized by Muon.
+        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
+            {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
+        adamw_lr: The learning rate for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        adamw_wd: The weight decay for the internal AdamW.
+    """
+
+    def __init__(
+        self,
+        muon_params,
+        lr=0.02,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=6,
+        adamw_params=None,
+        adamw_lr=3e-4,
+        adamw_betas=(0.95, 0.95),
+        adamw_eps=1e-8,
+        adamw_wd=0,
+    ):
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_lr_ratio=adamw_lr / lr,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+            adamw_wd=adamw_wd,
+        )
+
+        # handle list of params or list of dicts
+        if isinstance(muon_params, Generator):
+            muon_params = list(muon_params)
+        if isinstance(adamw_params, Generator):
+            adamw_params = list(adamw_params)
+        elif adamw_params is None:
+            adamw_params = []
+
+        super().__init__([*muon_params, *adamw_params], defaults)
+
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        # we can't pickle booleans for saving, so we will use 1=True, 0=False
+        def assign_muon(p):
+            if p.ndim >= 2 and p.size(0) < 10000:
+                self.state[p]["use_muon"] = 1
+            else:
+                self.state[p]["use_muon"] = 0
+
+        if isinstance(muon_params[0], dict):
+            for group in muon_params:
+                for p in group["params"]:
+                    assign_muon(p)
+        else:
+            for p in muon_params:
+                assign_muon(p)
+
+        def assign_adamw(p):
+            # Do not use Muon for parameters in adamw_params
+            self.state[p]["use_muon"] = 0
+
+        if len(adamw_params) and isinstance(adamw_params[0], dict):
+            for group in adamw_params:
+                for p in group["params"]:
+                    assign_adamw(p)
+        else:
+            for p in adamw_params:
+                assign_adamw(p)
+
+        if torch.distributed.is_initialized():
+            self.world_size = torch.distributed.get_world_size()
+            self.rank = torch.distributed.get_rank()
+        else:
+            self.world_size = 1
+            self.rank = 0
+
+    def step(self):
+        for group in self.param_groups:
+            lr = group["lr"]
+            momentum = group["momentum"]
+            for i, p in enumerate(group["params"]):
+                if self.state[p]["use_muon"] == 1:
+                    g = p.grad
+                    if g is None:
+                        continue
+                    if g.ndim > 2:
+                        g = g.view(g.size(0), -1)
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if group["nesterov"]:
+                        g = g.add(buf, alpha=momentum)
+
+                    meta = None
+                    if isinstance(g, DTensor):
+                        g, meta = to_local(g, keep_sharded=False)
+                    # gives NaNs when done with a DTensor instead of throwing a typical op-not-supported error, quite sneaky
+                    g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+                    if meta is not None:
+                        g = to_dist(g, **meta)
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+
+                    g = g.view_as(p.data).type_as(p.data)
+                    p.data.add_(g, alpha=-lr)
+                else:
+                    # these are all pointwise so we can stay in DTensor
+                    g = p.grad
+                    if g is None:
+                        continue
+                    state = self.state[p]
+                    if "step" not in state:
+                        state["step"] = 0
+                        state["moment1"] = torch.zeros_like(g)
+                        state["moment2"] = torch.zeros_like(g)
+                    state["step"] += 1
+                    step = state["step"]
+                    buf1 = state["moment1"]
+                    buf2 = state["moment2"]
+                    buf1.lerp_(g, 1 - group["adamw_betas"][0])
+                    buf2.lerp_(g.square(), 1 - group["adamw_betas"][1])
+
+                    g = buf1 / (group["adamw_eps"] + buf2.sqrt())
+
+                    bias_correction1 = 1 - group["adamw_betas"][0] ** step
+                    bias_correction2 = 1 - group["adamw_betas"][1] ** step
+                    scale = bias_correction1 / bias_correction2**0.5
+                    p.data.mul_(1 - lr * group["adamw_wd"])
+                    p.data.add_(g, alpha=-lr / scale)
+
+
+class AdamConfig(BaseConfig):
+    lr: float = 4e-4
+    weight_decay: float = 0.1
+    betas1: float = 0.9
+    betas2: float = 0.95
+
+
+class MuonConfig(BaseConfig):
+    pseudo_order_steps: int
+    lr: float = 0.02
+    momentum: float = 0.9
+    nesterov: bool = True
+
+    adam: AdamConfig = AdamConfig(lr=3e-4, betas1=0.95, betas2=0.95, weight_decay=0)
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 7ab7cb8d..22e6777d 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -21,6 +21,7 @@ from zeroband.comms import ElasticDeviceMesh
 from zeroband.loss import cross_entropy_max_z_loss
 from zeroband.models.llama.model import create_block_mask_from_seqlens
+from zeroband.optimizers import AdamConfig, OptimizersConfig, get_optimizer
 
 from zeroband.utils import (
     FakeTokenizer,
@@ -43,10 +44,7 @@ class OptimConfig(BaseConfig):
-    lr: float = 4e-4
-    weight_decay: float = 0.1
-    adam_betas1: float = 0.9
-    adam_betas2: float = 0.95
+    optim: OptimizersConfig = AdamConfig()
 
     sched_type: Literal["cosine", "linear", "wsd-sqrt"] = "cosine"
     warmup_steps: int = 1000
@@ -212,12 +210,7 @@ def train(config: Config):
     logger.debug("model fsdped")
 
     # Setup optimizers
-    inner_optimizer = torch.optim.AdamW(
-        model.parameters(),
-        lr=config.optim.lr,
-        weight_decay=config.optim.weight_decay,
-        betas=(config.optim.adam_betas1, config.optim.adam_betas2),
-    )
+    inner_optimizer = get_optimizer(model.parameters(), config.optim.optim)
 
     if config.diloco is not None:
         diloco = Diloco(config.diloco, model, elastic_device_mesh)
diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py
index e5703fe3..0cb02958 100644
--- a/tests/test_torchrun/test_train.py
+++ b/tests/test_torchrun/test_train.py
@@ -112,3 +112,25 @@ def test_packing(packing: bool):
     num_gpus = [2, 1]
     packing_arg = "--train.sequence_packing" if packing else "--no-train.sequence_packing"
     _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg])
+
+
+@pytest.mark.parametrize("diloco", [False, True])
+def test_muon(diloco: bool):
+    num_gpus = [1, 2] if diloco else [2, 1]
+    _test_multi_gpu(
+        num_gpus,
+        "debug/diloco.toml" if diloco else "debug/normal.toml",
+        extra_args=["--optim.optim.pseudo_order_steps", "6"],
+        diloco=diloco,
+    )
+
+
+@pytest.mark.parametrize("diloco", [False, True])
+def test_soap(diloco: bool):
+    num_gpus = [1, 2] if diloco else [2, 1]
+    _test_multi_gpu(
+        num_gpus,
+        "debug/diloco.toml" if diloco else "debug/normal.toml",
+        extra_args=["--optim.optim.precondition_frequency", "1"],
+        diloco=diloco,
+    )
diff --git a/uv.lock b/uv.lock
index 0b3b3f34..63c6a125 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1899,6 +1899,14 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/69/72/20cb30f3b39a9face296491a86adb6ff8f1a47a897e4d14667e6cf89d5c3/torch-2.5.1-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:9b61edf3b4f6e3b0e0adda8b3960266b9009d02b37555971f4d1c8f7a05afed7", size = 906393265 },
 ]
 
+[[package]]
+name = "torch-shampoo"
+version = "1.0.0"
+source = { git = "https://github.com/facebookresearch/optimizers.git?rev=main#f9d2a8cb526709bd4b5ef71f8cca3705906a0f94" }
+dependencies = [
+    { name = "torch" },
+]
+
 [[package]]
 name = "torchdata"
 version = "0.8.0"
@@ -2197,6 +2205,7 @@ dependencies = [
     { name = "setuptools" },
     { name = "toposolve" },
     { name = "torch" },
+    { name = "torch-shampoo" },
     { name = "torchdata" },
     { name = "transformers" },
     { name = "zstandard" },
@@ -2234,6 +2243,7 @@ requires-dist = [
     { name = "setuptools" },
     { name = "toposolve" },
     { name = "torch", specifier = "==2.5.1" },
+    { name = "torch-shampoo", git = "https://github.com/facebookresearch/optimizers.git?rev=main" },
     { name = "torchdata", specifier = ">=0.8.0" },
     { name = "transformers", specifier = ">=4.44.2" },
     { name = "wandb", marker = "extra == 'all'" },
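
For reference, a minimal sketch of driving the new get_optimizer entry point directly, outside the TOML/CLI config path. It assumes the zeroband package from this branch (with the torch-shampoo dependency) is installed; pseudo_order_steps=6 mirrors the Muon test above, and the toy Linear model is purely illustrative.

import torch

from zeroband.optimizers import get_optimizer
from zeroband.optimizers.muon import MuonConfig

# A toy 2D weight exercises the Muon (Newton-Schulz) branch; the 1D bias
# falls through to the optimizer's internal AdamW-style update.
model = torch.nn.Linear(64, 64)

config = MuonConfig(pseudo_order_steps=6)  # forwarded to Muon as ns_steps
optimizer = get_optimizer(list(model.parameters()), config)

loss = model(torch.randn(8, 64)).square().mean()
loss.backward()
optimizer.step()

In the TOML configs the same choice is expressed through the new [optim.optim] sub-table; as the tests suggest, setting a Muon-only field such as pseudo_order_steps selects Muon, while precondition_frequency selects the SOAP/Shampoo path.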