diff --git a/README.md b/README.md index b04cd13d..c8d6d172 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ install uv ```bash curl -LsSf https://astral.sh/uv/install.sh | sh -uv sync +uv sync --extra all ``` run your code using @@ -16,3 +16,22 @@ run your code using ```bash uv run ... ``` + +## quick check + +To check that everything is working you can do + +```bash +ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug.toml +``` + +## run test + +You need a machine with a least two gpus to run the full test suite. + +Some test must be run from the root directory. + +```bash +uv run pytest +``` + diff --git a/configs/150M/3090.toml b/configs/150M/3090.toml new file mode 100644 index 00000000..866d6054 --- /dev/null +++ b/configs/150M/3090.toml @@ -0,0 +1,12 @@ +name_model = "150M" +project = "debug_150m_zero_band" + +[train] +micro_bs = 16 # change this base on the gpu +sharding_strategy = "NO_SHARD" + +[optim] +batch_size = 512 +warmup_steps = 1000 +total_steps = 88_000 +lr = 4e-4 \ No newline at end of file diff --git a/configs/150M/A40.toml b/configs/150M/A40.toml new file mode 100644 index 00000000..e7799417 --- /dev/null +++ b/configs/150M/A40.toml @@ -0,0 +1,12 @@ +name_model = "150M" +project = "debug_150m_zero_band" + +[train] +micro_bs = 32 # change this base on the gpu +sharding_strategy = "NO_SHARD" + +[optim] +batch_size = 512 +warmup_steps = 1000 +total_steps = 88_000 +lr = 4e-4 \ No newline at end of file diff --git a/configs/150M/H100.toml b/configs/150M/H100.toml new file mode 100644 index 00000000..49a65475 --- /dev/null +++ b/configs/150M/H100.toml @@ -0,0 +1,12 @@ +name_model = "150M" +project = "debug_150m_zero_band" + +[train] +micro_bs = 64 # change this base on the gpu +sharding_strategy = "NO_SHARD" + +[optim] +batch_size = 512 +warmup_steps = 1000 +total_steps = 88_000 +lr = 4e-4 \ No newline at end of file diff --git a/configs/1B/H100.toml b/configs/1B/H100.toml new file mode 100644 index 00000000..1430dcea --- /dev/null +++ b/configs/1B/H100.toml @@ -0,0 +1,12 @@ +name_model = "1B" +project = "debug_1B_zero_band" + +[train] +micro_bs = 16 +sharding_strategy = "SHARD_GRAD_OP" + +[optim] +batch_size = 512 +warmup_steps = 1000 +total_steps = 88_000 +lr = 4e-4 \ No newline at end of file diff --git a/configs/7B/H100.toml b/configs/7B/H100.toml new file mode 100644 index 00000000..c1272c34 --- /dev/null +++ b/configs/7B/H100.toml @@ -0,0 +1,12 @@ +name_model = "7B" +project = "debug_7B_zero_band" + +[train] +micro_bs = 6 +sharding_strategy = "SHARD_GRAD_OP" + +[optim] +batch_size = 3840 +warmup_steps = 1000 +total_steps = 88_000 +lr = 6e-4 \ No newline at end of file diff --git a/configs/debug.toml b/configs/debug.toml new file mode 100644 index 00000000..e7d6e30d --- /dev/null +++ b/configs/debug.toml @@ -0,0 +1,13 @@ +name_model = "debugmodel" +project = "debug" + +[train] +micro_bs = 8 + +[optim] +batch_size = 16 +warmup_steps = 10 +total_steps = 5000 + +[data] +fake_data = true \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f6b103e3..f5b1a711 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,9 +10,14 @@ dependencies = [ "setuptools", "transformers>=4.44.2", "datasets>=3.0.0", - "pydantic_config @ git+https://github.com/samsja/pydantic_config.git@v0.2" + "pydantic_config @ git+https://github.com/samsja/pydantic_config.git@e529c9c", + "einops" ] +[project.optional-dependencies] +all = [ + "wandb", +] [build-system] requires = ["hatchling"] diff --git a/src/zeroband/__init__.py b/src/zeroband/__init__.py index 7bcbc130..e69de29b 100644 --- a/src/zeroband/__init__.py +++ b/src/zeroband/__init__.py @@ -1,2 +0,0 @@ -def hello() -> str: - return "Hello from zeroband!" diff --git a/src/zeroband/data.py b/src/zeroband/data.py new file mode 100644 index 00000000..93b2a391 --- /dev/null +++ b/src/zeroband/data.py @@ -0,0 +1,87 @@ +from functools import partial +from typing import Any, Generator + +import torch +from torch.utils.data import DataLoader +from torch.utils.data import IterableDataset + +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node + +TEST_VOCAB_SIZE = 1024 + +# TODO sami: make sure the init of the model is the same on all rank + + +class FakeTokenizedDataset(IterableDataset): + """This is a dummy dataset that generates random sequences of length seq_len and vocab_size""" + + def __init__(self, seq_len: int, vocab_size: int): + self.seq_len = seq_len + self.vocab_size = vocab_size + assert vocab_size > 3, "Vocab size must be greater than 3" + + def __iter__(self) -> Generator[dict[str, Any], Any, None]: + while True: + input_ids = torch.randint(3, self.vocab_size, (self.seq_len,)).tolist() + yield {"input_ids": input_ids} + + +def collate_causal_mask(max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100) -> callable: + """collate function for causal mask. Fill with padding tokens if sequence is shorter than max_seq_length""" + return partial(_collate_fn_causal_mask, max_seq_length=max_seq_length, pad_id=pad_id, ignore_index=ignore_index) + + +def _collate_fn_causal_mask( + samples: list[dict[str, torch.LongTensor]], max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100 +) -> dict[str, torch.LongTensor]: + """collate function for causal mask. Fill with padding tokens if sequence is shorter than max_seq_length. + input_ids and labels are both of size max_seq_length. + """ + + assert samples[0].keys() == {"input_ids"} + + batched = {"input_ids": [], "labels": []} + + if max_seq_length > 0: + max_seq_length += 1 # this makes sure that the effective seqlen is correct + + for sample in samples: + input_ids = torch.Tensor(sample["input_ids"]).long() + + if len(input_ids) < max_seq_length: + input_ids = torch.cat([input_ids, torch.full((max_seq_length - len(input_ids),), pad_id)]) + elif len(input_ids) > max_seq_length: + input_ids = input_ids[:max_seq_length] + + batched["input_ids"].append(input_ids[1:]) + batched["labels"].append(input_ids[:-1]) + + return {"input_ids": torch.stack(batched["input_ids"], dim=0), "labels": torch.stack(batched["labels"], dim=0)} + + +def get_dataloader( + tokenizer, world_size: int, rank: int, seq_length: int, batch_size: int, num_workers: int, fake_data: bool +) -> DataLoader: + if fake_data: + train_dataset = FakeTokenizedDataset(seq_length, TEST_VOCAB_SIZE) + else: + ds = load_dataset("allenai/c4", "en", streaming=True) + + def tokenize_function(data): + outputs = tokenizer(data["text"], truncation=True, max_length=seq_length, padding="max_length") + return outputs + + tokenized_datasets = ds.map( + tokenize_function, batched=True, remove_columns=["text", "timestamp", "url", "attention_mask"] + )["train"] + train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank) + + data_collator = collate_causal_mask(max_seq_length=seq_length, pad_id=tokenizer.pad_token_id, ignore_index=-100) + + return DataLoader( + train_dataset, + collate_fn=data_collator, + batch_size=batch_size, + num_workers=num_workers, + ) diff --git a/src/zeroband/models/__init__.py b/src/zeroband/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py new file mode 100644 index 00000000..5250fe57 --- /dev/null +++ b/src/zeroband/models/llama/__init__.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Llama 2 is licensed under the LLAMA 2 Community License, +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from zeroband.models.llama.model import ModelArgs, Transformer + +__all__ = ["Transformer"] + +llama2_configs = { + "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=8), + "150M": ModelArgs(dim=1024, n_layers=12, n_heads=16), # todo(sami): double check this + "271M": ModelArgs(dim=1024, n_layers=16, n_heads=8), + "1B": ModelArgs(dim=2048, n_layers=18, n_heads=16), + "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32), + "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40), + "26B": ModelArgs(dim=5120, n_layers=80, n_heads=40), + "70B": ModelArgs( + dim=8192, + n_layers=80, + n_heads=64, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=4096, + ), +} + +llama3_configs = { + "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000), + "8B": ModelArgs( + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=1024, + rope_theta=500000, + ), + "70B": ModelArgs( + dim=8192, + n_layers=80, + n_heads=64, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=4096, + rope_theta=500000, + ), + "405B": ModelArgs( + dim=16384, + n_layers=126, + n_heads=128, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=4096, + rope_theta=500000, + ), +} + + +def get_model(name_model: str, type_model: str, vocab_size: int) -> tuple[Transformer, ModelArgs]: + """get the transformer model""" + + if type_model == "llama2": + config = llama2_configs[name_model] + elif type_model == "llama3": + config = llama3_configs[name_model] + else: + raise ValueError(f"Model type {type_model} not supported") + + config.vocab_size = vocab_size + return Transformer(config), config diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py new file mode 100644 index 00000000..7a20eda7 --- /dev/null +++ b/src/zeroband/models/llama/model.py @@ -0,0 +1,444 @@ +# this code is copy pasted from the torchtitan repo https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py +# the commit at time of copy paste was commit f2a1551 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Llama 2 is licensed under the LLAMA 2 Community License, +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from zeroband.models.norms import build_norm + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + # If `True`, then each transformer block init uses its layer ID, and if + # `False`, each uses the total number of transformer blocks + depth_init: bool = True + norm_type: str = "rmsnorm" + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (ModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock Module + + Args: + layer_id (int): Identifier for the layer. + model_args (ModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. + + """ + + def __init__(self, layer_id: int, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.num_layers = model_args.n_layers + + self.attention_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention(self.attention_norm(x), freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class Transformer(nn.Module): + """ + Transformer Module + + Args: + model_args (ModelArgs): Model configuration arguments. + + Attributes: + model_args (ModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. + self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps) + + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + self.init_weights() + + def init_weights(self): + """ + [Note: On ``init_weights`` vs. ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. + """ + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights() + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # (use 2x max sequence length to be safe) + self.model_args.max_seq_len * 2, + self.model_args.rope_theta, + ) + + def forward(self, tokens: torch.Tensor): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices. + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + output = self.output(h).float() if self.output else h + return output + + @classmethod + def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + """ + Initialize a Transformer model from a ModelArgs object. + + Args: + model_args (ModelArgs): Model configuration arguments. + + Returns: + Transformer: Transformer model. + + """ + return cls(model_args) diff --git a/src/zeroband/models/norms.py b/src/zeroband/models/norms.py new file mode 100644 index 00000000..cd5c2f81 --- /dev/null +++ b/src/zeroband/models/norms.py @@ -0,0 +1,333 @@ +# this code is copy pasted from the torchtitan repo https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py +# the commit at time of copy paste was commit f2a1551 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from functools import partial + +import torch +import torch.nn as nn + +import triton +import triton.language as tl + +from torch.distributed._tensor import Partial, Replicate, Shard +from torch.distributed._tensor.experimental import local_map + + +def build_norm(norm_type: str, dim: int, eps: float = 1e-6): + """ + Builds the specified normalization layer based on the norm_type. + + Args: + norm_type (str): The type of normalization layer to build. + Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm + dim (int): The dimension of the normalization layer. + eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. + + Returns: + The built normalization layer. + + Raises: + NotImplementedError: If an unknown norm_type is provided. + """ + norm_type = norm_type.lower() # Normalize to lowercase + + if norm_type == "layernorm": + return nn.LayerNorm(dim, eps=eps, bias=False) + elif norm_type == "np_layernorm": + return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + elif norm_type == "rmsnorm": + return RMSNorm(dim, eps=eps) + elif norm_type == "fused_rmsnorm": + return FusedRMSNorm(dim, eps=eps) + else: + raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") + + +class FusedRMSNorm(nn.Module): + """Fused RMS Norm, wraps a fused Triton Kernel""" + + def __init__( + self, + dim: int, + eps: float = 1e-6, + ): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self.fused_rms_norm_fn = fused_rms_norm_fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """leverages Triton Fused RMS Norm kernel""" + return self.fused_rms_norm_fn( + x, + self.weight, + eps=self.eps, + ) + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +# FusedRMSNorm in Triton + +# Credit +# Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py +# Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +@triton.jit +def _rms_norm_fwd_kernel( + X, + stride_x, + Y, + stride_y, + W, + Rstd, + eps, + M, # num rows + N, # num cols + block_N: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, block_N) + + # Load input data and weights + mask = cols < N + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Compute mean and variance + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + + # Store the reciprocal standard deviation + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + x_hat = x * rstd + y = x_hat * w + + # Write output + tl.store(Y + row * stride_y + cols, y, mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +@triton.jit +def _rms_norm_bwd_kernel_sm( + X, + stride_x, + W, + DY, + stride_dy, + DX, + stride_dx, + Rstd, + DW, + eps, + M, # num rows + N, # num cols + rows_per_program, + block_N: tl.constexpr, +): + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, block_N) + mask = cols < N + + # Load weights + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Accumulate gradients for weights + dw = tl.zeros((block_N,), dtype=tl.float32) + + row_end = min(row_start + rows_per_program, M) + for row in range(row_start, row_end): + # Load input, output gradient, and reciprocal standard deviation + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32) + rstd = tl.load(Rstd + row) + + # Compute normalized input and gradients + x_hat = x * rstd + wdy = w * dy + dw += dy * x_hat + c1 = tl.sum(x_hat * wdy, axis=0) / N + dx = (wdy - x_hat * c1) * rstd + + # Store input gradient + tl.store(DX + row * stride_dx + cols, dx, mask=mask) + + # Store weight gradients + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + + +class TritonFusedRMSNorm(torch.autograd.Function): + @partial( + local_map, + out_placements=[Shard(1)], + in_placements=(None, [Shard(1)], [Replicate()], None), + ) + @staticmethod + def forward(ctx, x, weight, eps): + x_shape_start = x.shape + + # Flatten input + x = x.view(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if weight.stride(-1) != 1: + weight = weight.contiguous() + + M, N = x.shape + y = torch.empty_like(x) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (M,) # noqa: E731 + _rms_norm_fwd_kernel[grid]( + x, + x.stride(0), + y, + y.stride(0), + weight, + rstd, + eps, + M, + N, + block_N, + ) + + ctx.eps = eps + ctx.save_for_backward(x, weight, rstd) + ctx.x_shape_start = x_shape_start + + y = y.reshape(x_shape_start) + return y + + @partial( + local_map, + out_placements=([Shard(1)], [Partial()], None), + in_placements=(None, [Shard(1)]), + ) + @staticmethod + def backward(ctx, dy): + x, weight, rstd = ctx.saved_tensors + eps = ctx.eps + x_shape_start = ctx.x_shape_start + + # Flatten input and output gradients + dy = dy.view(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + + M, N = dy.shape + dx = torch.empty_like(x) + + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + rows_per_sm = math.ceil(M / sm_count) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (sm_count,) # noqa: E731 + _rms_norm_bwd_kernel_sm[grid]( + x, + x.stride(0), + weight, + dy, + dy.stride(0), + dx, + dx.stride(0), + rstd, + _dw, + eps, + M, + N, + rows_per_sm, + block_N, + ) + dw = _dw.sum(0).to(weight.dtype) + dx = dx.view(x_shape_start) + return dx, dw, None + + +# expose fusedRMSNorm as a function +def fused_rms_norm_fn( + x, + weight, + eps=1e-6, +): + return TritonFusedRMSNorm.apply( + x, + weight, + eps, + ) diff --git a/src/zeroband/train.py b/src/zeroband/train.py new file mode 100644 index 00000000..9988e767 --- /dev/null +++ b/src/zeroband/train.py @@ -0,0 +1,251 @@ +import os +from contextlib import nullcontext +import time +from typing import Literal + +import torch +from pydantic_config import parse_argv, BaseConfig +from torch.distributed import destroy_process_group, init_process_group +from einops import rearrange +from torch.nn import functional as F + +from transformers import ( + AutoTokenizer, + get_cosine_schedule_with_warmup, +) +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + MixedPrecision, +) +import torch.distributed as dist +from zeroband import utils + +from zeroband.utils import get_sharding_strategy +from zeroband.utils.monitor import WandbMonitor, DummyMonitor +from zeroband.data import TEST_VOCAB_SIZE, get_dataloader +from zeroband.models.llama import get_model +from zeroband.utils.world_info import get_world_info +from zeroband.utils.logging import get_logger + + +def ddp_setup(): + """ + Initialize the distributed process group. + """ + init_process_group() + torch.cuda.set_device(world_info.local_rank) + + +class DilocoConfig(BaseConfig): + outer_lr: float = 0.7 + inner_steps: int = 10 + + +class DataConfig(BaseConfig): + seq_length: int = 1024 + fake_data: bool = False + num_workers: int = 4 + + +class OptimConfig(BaseConfig): + lr: float = 4e-4 + weight_decay: float = 0.1 + adam_betas1: float = 0.9 + adam_betas2: float = 0.95 + + warmup_steps: int = 1000 + total_steps: int = 88_000 + batch_size: int = 512 + + +class TrainConfig(BaseConfig): + micro_bs: int + torch_compile: bool = True + sharding_strategy: str = "SHARD_GRAD_OP" + + +class Config(BaseConfig): + # main config + name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "13B", "26B", "70B"] = "150M" + type_model: Literal["llama2", "llama3"] = "llama2" + + project: str = "zeroband" + metric_logger_type: Literal["wandb", "dummy"] = "wandb" + + # sub config + diloco: DilocoConfig | None = None + data: DataConfig = DataConfig() + optim: OptimConfig = OptimConfig() + train: TrainConfig + + +def train(config: Config): + sharding_strategy = get_sharding_strategy(config.train.sharding_strategy) + + # batch_size is the total batch size for all GPUs + assert config.optim.batch_size % world_info.local_world_size == 0 + batch_size = config.optim.batch_size // world_info.local_world_size + + assert batch_size % config.train.micro_bs == 0 + gradient_accumulation_steps = batch_size // config.train.micro_bs + + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True) + tokenizer.pad_token = "" # todo(sami): remove padding tokens once we have context stuffing + + logger.debug("tokenizer loaded") + + train_dataloader = get_dataloader( + tokenizer=tokenizer, + world_size=world_info.world_size, + rank=world_info.rank, + seq_length=config.data.seq_length, + batch_size=config.train.micro_bs, + num_workers=config.data.num_workers, + fake_data=config.data.fake_data, + ) + + model, model_config = get_model( + config.name_model, + config.type_model, + vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE, + ) + model = model.to(world_info.local_rank) + logger.debug("model loaded") + + gpu_peak_flops = utils.get_peak_flops(torch.cuda.get_device_name(torch.device("cuda"))) + logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") + + num_params = utils.get_num_params(model, exclude_embedding=True) + logger.info(f"Number of parameters: {num_params}") + num_flop_per_token = utils.get_num_flop_per_token( + num_params, + model_config, + config.data.seq_length, + ) + + model = FSDP( + model, + sharding_strategy=sharding_strategy, + mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), + use_orig_params=True, + ) + + if config.train.torch_compile: + model = torch.compile(model) + logger.debug("model compiled and fsdped") + + # Setup optimizers + inner_optimizer = torch.optim.AdamW( + model.parameters(), + lr=config.optim.lr, + weight_decay=config.optim.weight_decay, + betas=(config.optim.adam_betas1, config.optim.adam_betas2), + ) + + scheduler = get_cosine_schedule_with_warmup( + inner_optimizer, + num_warmup_steps=config.optim.warmup_steps, + num_training_steps=config.optim.total_steps, + ) + + model.train() + + if world_info.rank == 0: + logger_cls = WandbMonitor if config.metric_logger_type == "wandb" else DummyMonitor + metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=False) + + train_dataloader_iterator = iter(train_dataloader) + + outer_step = 0 + num_inner_steps = config.diloco.inner_steps if config.diloco is not None else 1 + + logger.info("starting training") + while True: + if num_inner_steps > 1: + # if we don't use diloco we don't print the outer step logs + logger.info(f"outer_step step: {outer_step}") + + for inner_step in range(num_inner_steps): + loss_batch = 0 + beginning_step_time = time.time() + + for grad_acc_step in range(gradient_accumulation_steps): + is_accumulating = grad_acc_step < gradient_accumulation_steps - 1 + batch = next(train_dataloader_iterator) + input_ids = batch["input_ids"].to("cuda") + labels = batch["labels"].to("cuda") + + with model.no_sync() if is_accumulating else nullcontext(): + logits = model(tokens=input_ids).contiguous() + flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab") + flatten_labels = rearrange(labels, "b seq -> (b seq)") + + loss = ( + F.cross_entropy(flatten_logits, flatten_labels, ignore_index=-100) / gradient_accumulation_steps + ) + loss.backward() + loss_batch += loss.detach() + + model.clip_grad_norm_(1.0) # gradient clipping + inner_optimizer.step() + scheduler.step() + inner_optimizer.zero_grad() + + # logging + real_step = outer_step * num_inner_steps + inner_step + 1 # add + 1 because inner_step start at 0 + inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0] + + dist.all_reduce(loss_batch, op=dist.ReduceOp.AVG) + # syncing loss across all data parallel rank + # todo(sami): when using diloco make sure that the loss is computed only on local world + + time_taken = time.time() - beginning_step_time + tokens_per_second = config.data.seq_length * config.optim.batch_size / time_taken + + mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops + + metrics = { + "Loss": loss_batch.item(), + "step": real_step, + "inner_lr": inner_lr, + "tokens_per_second": tokens_per_second, + "Perplexity": torch.exp(loss_batch).item(), + "total_tokens": real_step * config.optim.batch_size * config.data.seq_length, + "mfu": mfu, + } + + if world_info.rank == 0: + metric_logger.log(metrics) + + logger.info( + f"step: {real_step}, loss: {loss_batch.item():.4f}, tokens_per_second: {metrics['tokens_per_second']:.2f}, mfu: {mfu:.2f}" + ) + + outer_step += 1 + + if real_step >= config.optim.total_steps: + # we only allow to break outisde of the inner loop. + # This avoid ending the training in the middle of a the inner loop + # Since ckpt strategy and all reduce is done at the outer loop level. + break + + if world_info.rank == 0: + metric_logger.finish() + + +if __name__ == "__main__": + # Allow eager fallback during production so that that the training runs dont die + # However, in development, we want to know that we broke torch compile + torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ + torch.set_float32_matmul_precision("high") + + world_info = get_world_info() + logger = get_logger() + + ddp_setup() + + config = Config(**parse_argv()) + logger.debug(f"config: {config.model_dump()}") + + train(config) + destroy_process_group() diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py new file mode 100644 index 00000000..a9b8fad2 --- /dev/null +++ b/src/zeroband/utils/__init__.py @@ -0,0 +1,68 @@ +import torch +from torch.distributed.fsdp import ShardingStrategy + + +__all__ = ["get_sharding_strategy", "get_peak_flops", "get_num_flop_per_token", "get_num_params"] + + +def get_sharding_strategy(sharding_strategy: str) -> ShardingStrategy: + if sharding_strategy == "FULL_SHARD": + return ShardingStrategy.FULL_SHARD + elif sharding_strategy == "SHARD_GRAD_OP": + return ShardingStrategy.SHARD_GRAD_OP + elif sharding_strategy == "NO_SHARD": + return ShardingStrategy.NO_SHARD + elif sharding_strategy == "HYBRID_SHARD": + return ShardingStrategy.HYBRID_SHARD + elif sharding_strategy == "_HYBRID_SHARD_ZERO2": + return ShardingStrategy._HYBRID_SHARD_ZERO2 + else: + raise ValueError( + f"Invalid sharding_strategy: {sharding_strategy}. Please choose 'FULL_SHARD', 'SHARD_GRAD_OP', 'NO_SHARD', 'HYBRID_SHARD', or '_HYBRID_SHARD_ZERO2'." + ) + + +### code above inspired and copied from https://github.com/pytorch/torchtitan/blob/4b3f2e41a084bf79a8540068ed525539d1244edd/torchtitan/utils.py#L119 + + +# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU +def get_peak_flops(device_name: str) -> int: + if "A100" in device_name: + # data from https://www.nvidia.com/en-us/data-center/a100/ + return 312e12 + elif "H100" in device_name: + # data from https://www.nvidia.com/en-us/data-center/h100/ + # NOTE: Specifications are one-half lower without sparsity. + if "NVL" in device_name: + return 835e12 + elif "PCIe" in device_name: + return 756e12 + else: # for H100 SXM and other variants + return 989e12 + else: # for other GPU types, assume A100 + return 312e12 + + +def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int: + l, h, q, t = ( # noqa: E741 + model_config.n_layers, + model_config.n_heads, + model_config.dim // model_config.n_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. we follow the convention and do not account for sparsity in causal attention + flop_per_token = 6 * num_params + 12 * l * h * q * t + + return flop_per_token + + +def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int: + num_params = sum(p.numel() for p in model.parameters()) + if exclude_embedding: + num_params -= model.tok_embeddings.weight.numel() + return num_params diff --git a/src/zeroband/utils/logging.py b/src/zeroband/utils/logging.py new file mode 100644 index 00000000..885ae3cd --- /dev/null +++ b/src/zeroband/utils/logging.py @@ -0,0 +1,40 @@ +import logging +import os + +from zeroband.utils.world_info import get_world_info + +logger = None + + +class CustomFormatter(logging.Formatter): + def __init__(self, local_rank: int): + super().__init__() + self.local_rank = local_rank + + def format(self, record): + log_format = "{asctime} [{levelname}] [Rank {local_rank}] {message}" + formatter = logging.Formatter(log_format, style="{", datefmt="%H:%M:%S") + record.local_rank = self.local_rank # Add this line to set the local rank in the record + return formatter.format(record) + + +def get_logger(): + global logger # Add this line to modify the global logger variable + if logger is not None: + return logger + + world_info = get_world_info() + logger = logging.getLogger(__name__) + + if world_info.local_rank == 0: + log_level = os.getenv("ZERO_BAND_LOG_LEVEL", "INFO") + logging.basicConfig(level=getattr(logging, log_level, logging.INFO)) + else: + logging.basicConfig(level=logging.CRITICAL) # Disable logging for non-zero ranks + + handler = logging.StreamHandler() + handler.setFormatter(CustomFormatter(world_info.local_rank)) + logger.addHandler(handler) + logger.propagate = False # Prevent the log messages from being propagated to the root logger + + return logger diff --git a/src/zeroband/utils/monitor.py b/src/zeroband/utils/monitor.py new file mode 100644 index 00000000..532515ef --- /dev/null +++ b/src/zeroband/utils/monitor.py @@ -0,0 +1,49 @@ +import pickle +from typing import Any, Protocol +import importlib + + +class Monitor(Protocol): + def __init__(self, project, config): ... + + def log(self, metrics: dict[str, Any]): ... + + def finish(self): ... + + +class WandbMonitor: + def __init__(self, project, config, resume: bool): + if importlib.util.find_spec("wandb") is None: + raise ImportError("wandb is not installed. Please install it to use WandbMonitor.") + + import wandb + + wandb.init( + project=project, config=config, resume="auto" if resume else None + ) # make wandb reuse the same run id if possible + + def log(self, metrics: dict[str, Any]): + import wandb + + wandb.log(metrics) + + def finish(self): + import wandb + + wandb.finish() + + +class DummyMonitor: + def __init__(self, project, config, *args, **kwargs): + self.project = project + self.config = config + open(project, "a").close() # Create an empty file at the project path + + self.data = [] + + def log(self, metrics: dict[str, Any]): + self.data.append(metrics) + + def finish(self): + with open(self.project, "wb") as f: + pickle.dump(self.data, f) diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py new file mode 100644 index 00000000..6ab3780f --- /dev/null +++ b/src/zeroband/utils/world_info.py @@ -0,0 +1,28 @@ +import os + +world_info = None + + +class WorldInfo: + """This class parse env var about torch world into class variables.""" + + world_size: int + rank: int + local_rank: int + local_world_size: int + + def __init__(self): + self.world_size = int(os.environ["WORLD_SIZE"]) + self.rank = int(os.environ["RANK"]) + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + + +def get_world_info() -> WorldInfo: + """ + Return a WorldInfo singleton. + """ + global world_info + if world_info is None: + world_info = WorldInfo() + return world_info diff --git a/tests/test_configs.py b/tests/test_configs.py new file mode 100644 index 00000000..0873750c --- /dev/null +++ b/tests/test_configs.py @@ -0,0 +1,19 @@ +""" +Tests all of the config file. usefull to catch mismatch key after a renaming of a arg name +Need to be run from the root folder +""" + +import os +from zeroband.train import Config +import pytest +import tomli + +config_file_names = [file for file in os.listdir("configs") if file.endswith(".toml")] + + +@pytest.mark.parametrize("config_file_name", config_file_names) +def test_load_config(config_file_name): + with open(f"configs/{config_file_name}", "rb") as f: + content = tomli.load(f) + config = Config(**content) + assert config is not None diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 00000000..8e7ef928 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,22 @@ +import pytest +import torch +from zeroband.models.llama import Transformer, llama2_configs + + +VOCAB_SIZE = 1024 + + +@pytest.fixture +def llama_config(): + config = llama2_configs["debugmodel"] + config.vocab_size = VOCAB_SIZE + return config + + +def test_llama(llama_config): + seq_len = 512 + bs = 8 + model = Transformer(llama_config) + input_ = torch.randint(0, llama_config.vocab_size, (bs, seq_len)) + output = model(input_) + assert output.shape == (bs, seq_len, llama_config.vocab_size) diff --git a/tests/test_torchrun/test_train b/tests/test_torchrun/test_train new file mode 100644 index 00000000..3295d5f5 --- /dev/null +++ b/tests/test_torchrun/test_train @@ -0,0 +1,206 @@ +import pickle +import subprocess +import numpy as np +import pytest +import socket +from hivemind.dht.dht import DHT +from open_diloco.ckpt_utils import CKPT_PREFIX + + +def get_random_available_port(): + # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +@pytest.fixture(scope="session") +def random_available_port(): + return get_random_available_port() + + +@pytest.fixture +def config() -> list[str]: + return [ + "--path_model", + "tests/models/llama-2m-fresh", + "--fake_data", + "--no-torch_compile", + "--lr", + "1e-2", + "--per_device_train_batch_size", + "8", + "--total_batch_size", + "16", + "--max_steps", + "50", + "--metric_logger_type", + "dummy", + ] + + +@pytest.mark.parametrize("num_gpu", [2]) +def test_multi_gpu_ckpt(config, random_available_port, num_gpu, tmp_path): + ckpt_path = f"{tmp_path}/ckpt" + log_file_1 = f"{tmp_path}/log1.json" + log_file_2 = f"{tmp_path}/log2.json" + + run_1 = ["--ckpt.path", ckpt_path, "--ckpt.interval", "10", "--project", log_file_1] + + cmd = [ + "torchrun", + f"--nproc_per_node={num_gpu}", + "--rdzv-endpoint", + f"localhost:{random_available_port}", + "open_diloco/train_fsdp.py", + *config, + ] + + result = subprocess.run(cmd + run_1) + + if result.returncode != 0: + pytest.fail(f"Process {result} failed {result.stderr}") + + run_2 = ["--ckpt.path", ckpt_path, "--ckpt.resume", f"{ckpt_path}/{CKPT_PREFIX}_20", "--project", log_file_2] + + results_resume = subprocess.run(cmd + run_2) + + if results_resume.returncode != 0: + pytest.fail(f"Process {result} failed {result.stderr}") + + with open(log_file_1, "rb") as f: + log1 = pickle.load(f) + with open(log_file_2, "rb") as f: + log2 = pickle.load(f) + + log1 = {data["step"]: [data["Loss"], data["lr"]] for data in log1} + log2 = {data["step"]: [data["Loss"], data["lr"]] for data in log2} + + common_step = set(log1.keys()) & set(log2.keys()) + + for step in common_step: + assert np.allclose(log1[step][0], log2[step][0], atol=1e-3), f"Loss at step {step} is different" + assert log1[step][1] == log2[step][1], f"Lr at step {step} is different" + + +@pytest.fixture +def config_hv() -> list[str]: + config = [ + "--path_model", + "tests/models/llama-2m-fresh", + "--fake_data", + "--no-torch_compile", + "--lr", + "1e-2", + "--per_device_train_batch_size", + "8", + "--total_batch_size", + "16", + "--max_steps", + "100", + "--metric_logger_type", + "dummy", + ] + + return config + [ + "--hv.local_steps", + "25", + "--hv.skip_load_from_peers", + "--hv.fail_rank_drop", + "--hv.matchmaking_time", + "5", + ] + + +@pytest.mark.parametrize("num_diloco", [2]) +def test_multi_gpu_hivemind(config_hv, num_diloco, tmp_path): + dht = DHT( + start=True, + host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"], + ) + + initial_peers = str(dht.get_visible_maddrs()[0]) + + results = [] + + ckpt_path = f"{tmp_path}/ckpt" + + def get_base_cmd(i, initial_peers): + return [ + "torchrun", + f"--nproc_per_node={1}", + "--rdzv-endpoint", + f"localhost:{port}", + "open_diloco/train_fsdp.py", + *config_hv, + "--hv.initial_peers", + initial_peers, + "--hv.world_rank", + str(i), + "--hv.galaxy_size", + str(num_diloco), + ] + + for i in range(num_diloco): + port = get_random_available_port() + + cmd = get_base_cmd(i, initial_peers) + [ + "--ckpt.path", + ckpt_path, + "--ckpt.interval", + "25", + "--project", + f"{tmp_path}/log{i}_part1.json", + ] + + result = subprocess.Popen(cmd) + results.append(result) + + for result in results: + result.wait() + if result.returncode != 0: + pytest.fail(f"Process {result} failed {result.stderr}") + + # resume from ckpt + + dht.shutdown() + + del dht + dht = DHT( + start=True, + host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"], + ) + initial_peers = str(dht.get_visible_maddrs()[0]) + + for i in range(num_diloco): + port = get_random_available_port() + + cmd = get_base_cmd(i, initial_peers) + [ + "--ckpt.resume", + f"{ckpt_path}/{CKPT_PREFIX}_50", + "--project", + f"{tmp_path}/log{i}_part2.json", + ] + + result = subprocess.Popen(cmd) + results.append(result) + + for result in results: + result.wait() + if result.returncode != 0: + pytest.fail(f"Process {result} failed {result.stderr}") + + for i in range(num_diloco): + with open(f"{tmp_path}/log{i}_part1.json", "rb") as f: + log1 = pickle.load(f) + with open(f"{tmp_path}/log{i}_part2.json", "rb") as f: + log2 = pickle.load(f) + + log1 = {data["step"]: [data["Loss"], data["lr"]] for data in log1} + log2 = {data["step"]: [data["Loss"], data["lr"]] for data in log2} + + common_step = set(log1.keys()) & set(log2.keys()) + + for step in common_step: + assert np.allclose(log1[step][0], log2[step][0], atol=1e-2), f"Loss at step {step} is different" + assert log1[step][1] == log2[step][1], f"Lr at step {step} is different" diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py new file mode 100644 index 00000000..61c7ec74 --- /dev/null +++ b/tests/test_torchrun/test_train.py @@ -0,0 +1,40 @@ +import subprocess +import pytest +import socket + + +def get_random_available_port(): + # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +@pytest.fixture() +def random_available_port(): + return get_random_available_port() + + +@pytest.fixture() +def config_path() -> str: + # need to be executed in the root dir + return "configs/debug.toml" + + +@pytest.mark.parametrize("num_gpu", [1, 2]) +def test_multi_gpu_ckpt(config_path, random_available_port, num_gpu): + cmd = [ + "torchrun", + f"--nproc_per_node={num_gpu}", + "--rdzv-endpoint", + f"localhost:{random_available_port}", + "src/zeroband/train.py", + f"@{config_path}", + "--optim.total_steps", + "10", + ] + + result = subprocess.run(cmd) + + if result.returncode != 0: + pytest.fail(f"Process {result} failed {result.stderr}") diff --git a/uv.lock b/uv.lock index 0d981321..f0ec766e 100644 --- a/uv.lock +++ b/uv.lock @@ -1,10 +1,18 @@ version = 1 requires-python = ">=3.10" resolution-markers = [ - "python_full_version < '3.11'", - "python_full_version == '3.11.*'", - "python_full_version < '3.13'", - "python_full_version >= '3.13'", + "python_full_version < '3.11' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform != 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform != 'linux'", + "python_full_version < '3.11' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform != 'linux'", + "python_full_version == '3.11.*' and sys_platform != 'linux'", + "python_full_version == '3.12.*' and sys_platform != 'linux'", + "python_full_version >= '3.13' and sys_platform == 'linux'", + "python_full_version >= '3.13' and sys_platform != 'linux'", ] [[package]] @@ -204,6 +212,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/28/76/e6222113b83e3622caa4bb41032d0b1bf785250607392e1b778aca0b8a7d/charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc", size = 48543 }, ] +[[package]] +name = "click" +version = "8.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "platform_system == 'Windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", size = 97941 }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -256,6 +276,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/41/9307e4f5f9976bc8b7fea0b66367734e8faf3ec84bc0d412d8cfabbb66cd/distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784", size = 468850 }, ] +[[package]] +name = "docker-pycreds" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/e6/d1f6c00b7221e2d7c4b470132c931325c8b22c51ca62417e300f5ce16009/docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4", size = 8754 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49", size = 8982 }, +] + +[[package]] +name = "einops" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/79/ca/9f5dcb8bead39959454c3912266bedc4c315839cee0e0ca9f4328f4588c1/einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85", size = 58861 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/5a/f0b9ad6c0a9017e62d4735daaeb11ba3b6c009d69a26141b258cd37b5588/einops-0.8.0-py3-none-any.whl", hash = "sha256:9572fb63046264a862693b0a87088af3bdc8c068fde03de63453cbbde245465f", size = 43223 }, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -342,6 +383,30 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "gitdb" +version = "4.0.11" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/0d/bbb5b5ee188dec84647a4664f3e11b06ade2bde568dbd489d9d64adef8ed/gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b", size = 394469 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/5b/8f0c4a5bb9fd491c277c21eff7ccae71b47d43c4446c9d0c6cff2fe8c2c4/gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4", size = 62721 }, +] + +[[package]] +name = "gitpython" +version = "3.1.43" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b6/a1/106fd9fa2dd989b6fb36e5893961f82992cf676381707253e0bf93eb1662/GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c", size = 214149 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/bd/cc3a402a6439c15c3d4294333e13042b915bbeab54edc457c723931fed3f/GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff", size = 207337 }, +] + [[package]] name = "huggingface-hub" version = "0.24.6" @@ -831,6 +896,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/92/caae8c86e94681b42c246f0bca35c059a2f0529e5b92619f6aba4cf7e7b6/pre_commit-3.8.0-py2.py3-none-any.whl", hash = "sha256:9a90a53bf82fdd8778d58085faf8d83df56e40dfe18f45b19446e26bf1b3a63f", size = 204643 }, ] +[[package]] +name = "protobuf" +version = "5.28.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/a4/4579a61de526e19005ceeb93e478b61d77aa38c8a85ad958ff16a9906549/protobuf-5.28.2.tar.gz", hash = "sha256:59379674ff119717404f7454647913787034f03fe7049cbef1d74a97bb4593f0", size = 422494 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/30/231764750e0987755b7b8d66771f161e5f002e165d27b72154c776dbabf7/protobuf-5.28.2-cp310-abi3-win32.whl", hash = "sha256:eeea10f3dc0ac7e6b4933d32db20662902b4ab81bf28df12218aa389e9c2102d", size = 419662 }, + { url = "https://files.pythonhosted.org/packages/7d/46/3fdf7462160135aee6a530f1ec66665b5b4132fa2e1002ab971bc6ec2589/protobuf-5.28.2-cp310-abi3-win_amd64.whl", hash = "sha256:2c69461a7fcc8e24be697624c09a839976d82ae75062b11a0972e41fd2cd9132", size = 431479 }, + { url = "https://files.pythonhosted.org/packages/37/45/d2a760580f8f2ed2825ba44cb370e0a4011ddef85e728f46ea3dd565a8a5/protobuf-5.28.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a8b9403fc70764b08d2f593ce44f1d2920c5077bf7d311fefec999f8c40f78b7", size = 414736 }, + { url = "https://files.pythonhosted.org/packages/e6/23/ed718dc18e6a561445ece1e7a17d2dda0c634ad9cf663102b47f10005d8f/protobuf-5.28.2-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:35cfcb15f213449af7ff6198d6eb5f739c37d7e4f1c09b5d0641babf2cc0c68f", size = 316518 }, + { url = "https://files.pythonhosted.org/packages/23/08/a1ce0415a115c2b703bfa798f06f0e43ca91dbe29d6180bf86a9287b15e2/protobuf-5.28.2-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:5e8a95246d581eef20471b5d5ba010d55f66740942b95ba9b872d918c459452f", size = 316605 }, + { url = "https://files.pythonhosted.org/packages/9b/55/f24e3b801d2e108c48aa2b1b59bb791b5cffba89465cbbf66fc98de89270/protobuf-5.28.2-py3-none-any.whl", hash = "sha256:52235802093bd8a2811abbe8bf0ab9c5f54cca0a751fdd3f6ac2a21438bffece", size = 169566 }, +] + +[[package]] +name = "psutil" +version = "6.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/c7/8c6872f7372eb6a6b2e4708b88419fb46b857f7a2e1892966b851cc79fc9/psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", size = 508067 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/66/78c9c3020f573c58101dc43a44f6855d01bbbd747e24da2f0c4491200ea3/psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35", size = 249766 }, + { url = "https://files.pythonhosted.org/packages/e1/3f/2403aa9558bea4d3854b0e5e567bc3dd8e9fbc1fc4453c0aa9aafeb75467/psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1", size = 253024 }, + { url = "https://files.pythonhosted.org/packages/0b/37/f8da2fbd29690b3557cca414c1949f92162981920699cd62095a984983bf/psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0", size = 250961 }, + { url = "https://files.pythonhosted.org/packages/35/56/72f86175e81c656a01c4401cd3b1c923f891b31fbcebe98985894176d7c9/psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", size = 287478 }, + { url = "https://files.pythonhosted.org/packages/19/74/f59e7e0d392bc1070e9a70e2f9190d652487ac115bb16e2eff6b22ad1d24/psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", size = 290455 }, + { url = "https://files.pythonhosted.org/packages/cd/5f/60038e277ff0a9cc8f0c9ea3d0c5eb6ee1d2470ea3f9389d776432888e47/psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132", size = 292046 }, + { url = "https://files.pythonhosted.org/packages/8b/20/2ff69ad9c35c3df1858ac4e094f20bd2374d33c8643cf41da8fd7cdcb78b/psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d", size = 253560 }, + { url = "https://files.pythonhosted.org/packages/73/44/561092313ae925f3acfaace6f9ddc4f6a9c748704317bad9c8c8f8a36a79/psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3", size = 257399 }, + { url = "https://files.pythonhosted.org/packages/7c/06/63872a64c312a24fb9b4af123ee7007a306617da63ff13bcc1432386ead7/psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0", size = 251988 }, +] + [[package]] name = "pyarrow" version = "17.0.0" @@ -880,7 +976,7 @@ wheels = [ [[package]] name = "pydantic-config" version = "0.2.0" -source = { git = "https://github.com/samsja/pydantic_config.git?rev=v0.2#e50503071fcfbfdbfd9442dc45eb853a4033565d" } +source = { git = "https://github.com/samsja/pydantic_config.git?rev=e529c9c#e529c9ca7f3bd5581e2e8bab013faa6d2996810a" } dependencies = [ { name = "pydantic" }, { name = "rich" }, @@ -1227,6 +1323,67 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/46/5d11dc300feaad285c2f1bd784ff3f689f5e0ab6be49aaf568f3a77019eb/safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:21742b391b859e67b26c0b2ac37f52c9c0944a879a25ad2f9f9f3cd61e7fda8f", size = 606660 }, ] +[[package]] +name = "sentry-sdk" +version = "2.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3c/23/6527e56fb17817153c37d702d6b9ed0a2f75ed213fd98a176c1b8894ad20/sentry_sdk-2.14.0.tar.gz", hash = "sha256:1e0e2eaf6dad918c7d1e0edac868a7bf20017b177f242cefe2a6bcd47955961d", size = 282948 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/de/956ce1d71459fa1af0486ca141fc605ac16f7c8855750668ff663e2b436a/sentry_sdk-2.14.0-py2.py3-none-any.whl", hash = "sha256:b8bc3dc51d06590df1291b7519b85c75e2ced4f28d9ea655b6d54033503b5bf4", size = 311425 }, +] + +[[package]] +name = "setproctitle" +version = "1.3.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ff/e1/b16b16a1aa12174349d15b73fd4b87e641a8ae3fb1163e80938dbbf6ae98/setproctitle-1.3.3.tar.gz", hash = "sha256:c913e151e7ea01567837ff037a23ca8740192880198b7fbb90b16d181607caae", size = 27253 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/cc/c51e6371f640a9adbe693ddb89d68596e5a8e4b5e05b4d3c65ec504e2f6d/setproctitle-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:897a73208da48db41e687225f355ce993167079eda1260ba5e13c4e53be7f754", size = 16954 }, + { url = "https://files.pythonhosted.org/packages/c3/7d/d03f319e0f3b3a6e98731a56cd4d81478ed0c12531b822fd2c728b948edb/setproctitle-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c331e91a14ba4076f88c29c777ad6b58639530ed5b24b5564b5ed2fd7a95452", size = 11304 }, + { url = "https://files.pythonhosted.org/packages/9c/56/6f4a4e80b2810eb7ea9ab355022c780ef80457de368ab5b6b21b795e4f05/setproctitle-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbbd6c7de0771c84b4aa30e70b409565eb1fc13627a723ca6be774ed6b9d9fa3", size = 31249 }, + { url = "https://files.pythonhosted.org/packages/d0/ae/010811bece9a59a8bba131d9e7acea9c2e3c3cbf544bf06d8b10b8c28ff5/setproctitle-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c05ac48ef16ee013b8a326c63e4610e2430dbec037ec5c5b58fcced550382b74", size = 32594 }, + { url = "https://files.pythonhosted.org/packages/87/7b/69bdc791001250dff279a1a81904f3f563caece4fa1607a95b9fd5197d6e/setproctitle-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1342f4fdb37f89d3e3c1c0a59d6ddbedbde838fff5c51178a7982993d238fe4f", size = 29713 }, + { url = "https://files.pythonhosted.org/packages/79/e7/54b36be02aee8ad573be68f6f46fd62838735c2f007b22df50eb5e13a20d/setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc74e84fdfa96821580fb5e9c0b0777c1c4779434ce16d3d62a9c4d8c710df39", size = 30755 }, + { url = "https://files.pythonhosted.org/packages/69/a7/2a77b68c11db87c22350381d6ce022011eb420076790e0e3697153e89458/setproctitle-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9617b676b95adb412bb69645d5b077d664b6882bb0d37bfdafbbb1b999568d85", size = 38562 }, + { url = "https://files.pythonhosted.org/packages/9d/09/bc108723bbfb7c50c22fdf22191f3e32abcb5d6f46610018030b25f601c5/setproctitle-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6a249415f5bb88b5e9e8c4db47f609e0bf0e20a75e8d744ea787f3092ba1f2d0", size = 36991 }, + { url = "https://files.pythonhosted.org/packages/94/ad/4166381d79f6ae8138be9b49f05d193a8deb748debace9896dffad45a753/setproctitle-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:38da436a0aaace9add67b999eb6abe4b84397edf4a78ec28f264e5b4c9d53cd5", size = 39866 }, + { url = "https://files.pythonhosted.org/packages/3d/92/17168f4bb1a695094e93e73a1ef1f7b89953a6d91e8a7699a2c840ba712f/setproctitle-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:da0d57edd4c95bf221b2ebbaa061e65b1788f1544977288bdf95831b6e44e44d", size = 38221 }, + { url = "https://files.pythonhosted.org/packages/0c/1b/753432a877bcdfb099e280795c86ac7dc245d9651b98308f606bb3db610d/setproctitle-1.3.3-cp310-cp310-win32.whl", hash = "sha256:a1fcac43918b836ace25f69b1dca8c9395253ad8152b625064415b1d2f9be4fb", size = 11064 }, + { url = "https://files.pythonhosted.org/packages/29/ff/80a02c5b414c2d3ff49c36c0a571a94aa3b4236f07eee39f72ebdb7314a0/setproctitle-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:200620c3b15388d7f3f97e0ae26599c0c378fdf07ae9ac5a13616e933cbd2086", size = 11815 }, + { url = "https://files.pythonhosted.org/packages/c9/17/7f9d5ddf4cfc4386e74565ccf63b8381396336e4629bb165b52b803ceddb/setproctitle-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:334f7ed39895d692f753a443102dd5fed180c571eb6a48b2a5b7f5b3564908c8", size = 16948 }, + { url = "https://files.pythonhosted.org/packages/ff/5d/77edf4c29c8d6728b49d3f0abb22159bb9c0c4ddebd721c09486b34985c8/setproctitle-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:950f6476d56ff7817a8fed4ab207727fc5260af83481b2a4b125f32844df513a", size = 11305 }, + { url = "https://files.pythonhosted.org/packages/13/f0/263954ca925a278036f100405e7ba82d4341e1e6bdc09f35362a7b40f684/setproctitle-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:195c961f54a09eb2acabbfc90c413955cf16c6e2f8caa2adbf2237d1019c7dd8", size = 31578 }, + { url = "https://files.pythonhosted.org/packages/79/52/503b546da451deb78fde27fec96c39d3f63a7958be60c9a837de89f47a0d/setproctitle-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f05e66746bf9fe6a3397ec246fe481096664a9c97eb3fea6004735a4daf867fd", size = 32910 }, + { url = "https://files.pythonhosted.org/packages/48/72/aeb734419a58a85ca7845c3d0011c322597da4ff601ebbc28f6c1dfd1ae8/setproctitle-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b5901a31012a40ec913265b64e48c2a4059278d9f4e6be628441482dd13fb8b5", size = 30086 }, + { url = "https://files.pythonhosted.org/packages/fd/df/44b267cb8f073a4ae77e120f0705ab3a07165ad90cecd4881b34c7e1e37b/setproctitle-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64286f8a995f2cd934082b398fc63fca7d5ffe31f0e27e75b3ca6b4efda4e353", size = 31076 }, + { url = "https://files.pythonhosted.org/packages/82/c2/79ad43c914418cb1920e0198ac7326061c05cd4ec75c86ed0ca456b7e957/setproctitle-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:184239903bbc6b813b1a8fc86394dc6ca7d20e2ebe6f69f716bec301e4b0199d", size = 41226 }, + { url = "https://files.pythonhosted.org/packages/81/1b/0498c36a07a73d39a7070f45d96a299006e624efc07fc2e2296286237316/setproctitle-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:664698ae0013f986118064b6676d7dcd28fefd0d7d5a5ae9497cbc10cba48fa5", size = 39723 }, + { url = "https://files.pythonhosted.org/packages/3a/fe/ebbcffd6012b9cf5edb017a9c30cfc2beccf707f5bf495da8cf69b4abe69/setproctitle-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e5119a211c2e98ff18b9908ba62a3bd0e3fabb02a29277a7232a6fb4b2560aa0", size = 42773 }, + { url = "https://files.pythonhosted.org/packages/64/b1/5786c0442435eb18d04299c8ce7d1f86feb5154444ac684963527a76e169/setproctitle-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:417de6b2e214e837827067048f61841f5d7fc27926f2e43954567094051aff18", size = 41089 }, + { url = "https://files.pythonhosted.org/packages/33/fb/14b41e920406a12de0a164ef3b86d62edb4fac63d91d9f86f3b80dae5b38/setproctitle-1.3.3-cp311-cp311-win32.whl", hash = "sha256:6a143b31d758296dc2f440175f6c8e0b5301ced3b0f477b84ca43cdcf7f2f476", size = 11066 }, + { url = "https://files.pythonhosted.org/packages/7e/ba/f6da9ba74e8c2c662e932b27a01025c1bee2846222f6a2e87a69c259772f/setproctitle-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a680d62c399fa4b44899094027ec9a1bdaf6f31c650e44183b50d4c4d0ccc085", size = 11817 }, + { url = "https://files.pythonhosted.org/packages/32/22/9672612b194e4ac5d9fb67922ad9d30232b4b66129b0381ab5efeb6ae88f/setproctitle-1.3.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d4460795a8a7a391e3567b902ec5bdf6c60a47d791c3b1d27080fc203d11c9dc", size = 16917 }, + { url = "https://files.pythonhosted.org/packages/49/e5/562ff00f2f3f4253ff8fa6886e0432b8eae8cde82530ac19843d8ed2c485/setproctitle-1.3.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bdfd7254745bb737ca1384dee57e6523651892f0ea2a7344490e9caefcc35e64", size = 11264 }, + { url = "https://files.pythonhosted.org/packages/8f/1f/f97ea7bf71c873590a63d62ba20bf7294439d1c28603e5c63e3616c2131a/setproctitle-1.3.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:477d3da48e216d7fc04bddab67b0dcde633e19f484a146fd2a34bb0e9dbb4a1e", size = 31907 }, + { url = "https://files.pythonhosted.org/packages/66/fb/2d90806b9a2ed97c140baade3d1d2d41d3b51458300a2d999268be24d21d/setproctitle-1.3.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ab2900d111e93aff5df9fddc64cf51ca4ef2c9f98702ce26524f1acc5a786ae7", size = 33333 }, + { url = "https://files.pythonhosted.org/packages/38/39/e7ce791f5635f3a16bd21d6b79bd9280c4c4aed8ab936b4b21334acf05a7/setproctitle-1.3.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:088b9efc62d5aa5d6edf6cba1cf0c81f4488b5ce1c0342a8b67ae39d64001120", size = 30573 }, + { url = "https://files.pythonhosted.org/packages/20/22/fd76bbde4194d4e31d5b31a02f80c8e7e54a99d3d8ff34f3d656c6655689/setproctitle-1.3.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6d50252377db62d6a0bb82cc898089916457f2db2041e1d03ce7fadd4a07381", size = 31601 }, + { url = "https://files.pythonhosted.org/packages/51/5c/a6257cc68e17abcc4d4a78cc6666aa0d3805af6d942576625c4a468a72f0/setproctitle-1.3.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:87e668f9561fd3a457ba189edfc9e37709261287b52293c115ae3487a24b92f6", size = 40717 }, + { url = "https://files.pythonhosted.org/packages/db/31/4f0faad7ef641be4e8dfcbc40829775f2d6a4ca1ff435a4074047fa3dad1/setproctitle-1.3.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:287490eb90e7a0ddd22e74c89a92cc922389daa95babc833c08cf80c84c4df0a", size = 39384 }, + { url = "https://files.pythonhosted.org/packages/22/17/8763dc4f9ddf36af5f043ceec213b0f9f45f09fd2d5061a89c699aabe8b0/setproctitle-1.3.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:4fe1c49486109f72d502f8be569972e27f385fe632bd8895f4730df3c87d5ac8", size = 42350 }, + { url = "https://files.pythonhosted.org/packages/7b/b2/2403cecf2e5c5b4da22f7d9df4b2149bf92d03a3422185e682e81055549c/setproctitle-1.3.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4a6ba2494a6449b1f477bd3e67935c2b7b0274f2f6dcd0f7c6aceae10c6c6ba3", size = 40704 }, + { url = "https://files.pythonhosted.org/packages/5e/c1/11e80061ac06aece2a0ffcaf018cdc088aebb2fc586f68201755518532ad/setproctitle-1.3.3-cp312-cp312-win32.whl", hash = "sha256:2df2b67e4b1d7498632e18c56722851ba4db5d6a0c91aaf0fd395111e51cdcf4", size = 11057 }, + { url = "https://files.pythonhosted.org/packages/90/e8/ece468e93e99d3b2826e9649f6d03e80f071d451e20c742f201f77d1bea1/setproctitle-1.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:f38d48abc121263f3b62943f84cbaede05749047e428409c2c199664feb6abc7", size = 11809 }, + { url = "https://files.pythonhosted.org/packages/24/55/8b369b56007a5a2c7594cdb58cd4a09d7cca65b28483bb5582c6975663f1/setproctitle-1.3.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6b9e62ddb3db4b5205c0321dd69a406d8af9ee1693529d144e86bd43bcb4b6c0", size = 10726 }, + { url = "https://files.pythonhosted.org/packages/35/30/ac99ecae8458ba995f85aa3aa911004679b405922e1487b0fba6fe8f4d37/setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e3b99b338598de0bd6b2643bf8c343cf5ff70db3627af3ca427a5e1a1a90dd9", size = 13368 }, + { url = "https://files.pythonhosted.org/packages/70/1d/3b2249c833c7d52b59ff0602d760df0543dc1e6c272f145b949750edeb01/setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38ae9a02766dad331deb06855fb7a6ca15daea333b3967e214de12cfae8f0ef5", size = 12969 }, + { url = "https://files.pythonhosted.org/packages/76/78/97f36752438cb5c6409b53eb3b1a334827cede43acab65e4fc4a0014cf9f/setproctitle-1.3.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:200ede6fd11233085ba9b764eb055a2a191fb4ffb950c68675ac53c874c22e20", size = 11848 }, +] + [[package]] name = "setuptools" version = "74.1.2" @@ -1245,6 +1402,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254", size = 11053 }, ] +[[package]] +name = "smmap" +version = "5.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/88/04/b5bf6d21dc4041000ccba7eb17dd3055feb237e7ffc2c20d3fae3af62baa/smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62", size = 22291 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/a5/10f97f73544edcdef54409f1d839f6049a0d79df68adbc1ceb24d1aaca42/smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da", size = 24282 }, +] + [[package]] name = "sympy" version = "1.13.2" @@ -1397,7 +1563,7 @@ name = "triton" version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock" }, + { name = "filelock", marker = "python_full_version < '3.13'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 }, @@ -1446,6 +1612,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/ea/12f774a18b55754c730c8383dad8f10d7b87397d1cb6b2b944c87381bb3b/virtualenv-20.26.4-py3-none-any.whl", hash = "sha256:48f2695d9809277003f30776d155615ffc11328e6a0a8c1f0ec80188d7874a55", size = 6013327 }, ] +[[package]] +name = "wandb" +version = "0.18.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "docker-pycreds" }, + { name = "gitpython" }, + { name = "platformdirs" }, + { name = "protobuf" }, + { name = "psutil" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "sentry-sdk" }, + { name = "setproctitle" }, + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/e4/ca1da2dde43886e7daf2260e4dcbd4daed9b00599ee12432cadc2dab4ca3/wandb-0.18.1.tar.gz", hash = "sha256:d625e94d53ff4ff961c58a9a17f0a1ea35720d98b9db710a458235924469fc6b", size = 6238045 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/5b/ab5c2e69c9f49fdea2f83c3f8e15d9388a92e9c5639dd3618a2d5d5cd144/wandb-0.18.1-py3-none-any.whl", hash = "sha256:be936a193eeb940ce03d966f013b847562497e76256852d5fb170cdcdf50f185", size = 5125929 }, + { url = "https://files.pythonhosted.org/packages/f5/8d/298e1a8e1c101894b0805e197667d910e3c0ed46ce537d26c5d3ec1081f1/wandb-0.18.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1f143b814b0fd51b5f1a676ad8b66bd06a5ee4ad22fc46bcbf24048d76c77d35", size = 6636762 }, + { url = "https://files.pythonhosted.org/packages/aa/fc/6832f3546ee43db973748dd0153a1e6c11b1af5cf29bc1187498620f83f3/wandb-0.18.1-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:86b73a9f94f18b07f0e937ae945560244b560b57c16a9dfb8f03e2516d0cc666", size = 6708580 }, + { url = "https://files.pythonhosted.org/packages/dd/66/5c5e76b0c5a0016d9b935e961ce4444ec280af43af7512258490533630d9/wandb-0.18.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc404682ebfb2477b48cb436a331e1bea0262e002d6fb3ccafe71d13657dd4ee", size = 9281298 }, + { url = "https://files.pythonhosted.org/packages/a8/64/6b1549a02151c3b8426e54fc7011733fa284483151a0189c85b309c9ec4e/wandb-0.18.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd4c97d69242efd604c1a2077c8b56341e236cfaca78c40f59dcef9b95464fdc", size = 9663908 }, + { url = "https://files.pythonhosted.org/packages/85/5e/bbf9937120b6a95cf859179eb77eee7bde7f63c365efa23c36ca2331c579/wandb-0.18.1-py3-none-win32.whl", hash = "sha256:33c5a0d74bc28879917b519f24d69b0e81530d72e99aba1c115189a2c9aac9cf", size = 6787975 }, + { url = "https://files.pythonhosted.org/packages/63/a8/397fb9a7d6e78136efd6765744f7a992c3c9a119f13448ded2b4885b88e7/wandb-0.18.1-py3-none-win_amd64.whl", hash = "sha256:559cbd6e9ab752622f7d6dacdc334ede7f1bc34f42df3f48ed32bde55db42c6e", size = 6787977 }, +] + [[package]] name = "xxhash" version = "3.5.0" @@ -1598,6 +1792,7 @@ version = "0.1.0" source = { editable = "." } dependencies = [ { name = "datasets" }, + { name = "einops" }, { name = "numpy" }, { name = "pydantic-config" }, { name = "setuptools" }, @@ -1605,6 +1800,11 @@ dependencies = [ { name = "transformers" }, ] +[package.optional-dependencies] +all = [ + { name = "wandb" }, +] + [package.dev-dependencies] dev = [ { name = "pre-commit" }, @@ -1615,11 +1815,13 @@ dev = [ [package.metadata] requires-dist = [ { name = "datasets", specifier = ">=3.0.0" }, + { name = "einops" }, { name = "numpy" }, - { name = "pydantic-config", git = "https://github.com/samsja/pydantic_config.git?rev=v0.2" }, + { name = "pydantic-config", git = "https://github.com/samsja/pydantic_config.git?rev=e529c9c" }, { name = "setuptools" }, { name = "torch", specifier = "==2.4.1" }, { name = "transformers", specifier = ">=4.44.2" }, + { name = "wandb", marker = "extra == 'all'" }, ] [package.metadata.requires-dev]