diff --git a/.github/workflows/gpu.yml b/.github/workflows/gpu.yml index c2196cb3..fb2ceec7 100644 --- a/.github/workflows/gpu.yml +++ b/.github/workflows/gpu.yml @@ -24,16 +24,10 @@ jobs: enable-cache: true cache-dependency-glob: "uv.lock" - - - name: Set up Python run: uv python install 3.10.13 - name: Install the project run: uv sync --all-extras --dev - - - name: Install flash attention - run: uv pip install flash-attn --no-build-isolation - - name: Run tests run: uv run pytest tests \ No newline at end of file diff --git a/README.md b/README.md index 2bee50b9..a1035756 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,6 @@ sudo apt install iperf -y uv venv source .venv/bin/activate uv sync --extra all -uv pip install flash-attn==2.6.3 --no-build-isolation git submodule update --init --recursive ``` diff --git a/pyproject.toml b/pyproject.toml index 3176bd31..da393919 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "ZeroBand is a production ready codebase for decentralized trainin readme = "README.md" requires-python = ">=3.10" dependencies = [ - "torch==2.4.1", + "torch==2.5.1", "numpy", "setuptools", "transformers>=4.44.2", diff --git a/scripts/install/install.sh b/scripts/install/install.sh index 5c1480c4..070b8e73 100755 --- a/scripts/install/install.sh +++ b/scripts/install/install.sh @@ -46,10 +46,7 @@ main() { log_info "Installing dependencies..." uv sync --extra all - - log_info "Installing flash-attn..." - uv pip install flash-attn==2.6.3 --no-build-isolation - + log_info "Updating git submodules..." git submodule update --init --recursive diff --git a/src/zeroband/data.py b/src/zeroband/data.py index be1219db..ed90f61c 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -125,7 +125,7 @@ def load_state_dict(self, state_dict): self.state = SequencePackingDataSetState(**state_dict["state"]) -def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.LongTensor]: +def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.LongTensor | list[torch.LongTensor]]: assert samples[0].keys() == {"input_ids", "labels", "seqlens"} inputs_ids = [] @@ -136,12 +136,12 @@ def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.Lo inputs_ids.append(sample["input_ids"]) labels.append(sample["labels"]) - seqlens.extend(sample["seqlens"]) + seqlens.append(torch.Tensor(sample["seqlens"]).long()) return { "input_ids": torch.stack(inputs_ids, dim=0), "labels": torch.stack(labels, dim=0), - "seqlens": torch.Tensor(seqlens).long(), + "seqlens": seqlens, } diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py index 5dbd2c97..30b54963 100644 --- a/src/zeroband/models/llama/__init__.py +++ b/src/zeroband/models/llama/__init__.py @@ -81,7 +81,10 @@ def get_model( - name_model: str, type_model: str, vocab_size: int, seq_length: int, attn_fn: str + name_model: str, + type_model: str, + vocab_size: int, + seq_length: int, ) -> tuple[Transformer, ModelArgs]: """get the transformer model""" @@ -94,6 +97,5 @@ def get_model( config.vocab_size = vocab_size config.max_seq_len = seq_length - config.attn_fn = attn_fn return Transformer(config), config diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py index 90756f78..234b8fee 100644 --- a/src/zeroband/models/llama/model.py +++ b/src/zeroband/models/llama/model.py @@ -12,18 +12,31 @@ from dataclasses import dataclass -from importlib.util import find_spec -from typing import Literal, Optional, Tuple -from einops import rearrange +from typing import Optional, Tuple import torch import torch.nn.functional as F from torch import nn from zeroband.models.norms import build_norm -flash_attn_available = find_spec("flash_attn") is not None -if flash_attn_available: - from flash_attn import flash_attn_func, flash_attn_varlen_func +from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE + +_flex_attention_compiled = torch.compile(flex_attention, dynamic=False) + + +# copied from https://github.com/pytorch/torchtune/blob/f2bd4bc25b24587aef40f486087412b9da8f1d94/torchtune/modules/attention_utils.py#L27 +# We cannot do nested compile, but flex attention only has perf benefits +# when compiled. To insulate it from the compiler, we wrap it with +# compiler.disable so that it can be used regardless of whether the model +# is compiled or not, and flex attention always remains compiled. +@torch.compiler.disable(recursive=False) +def flex_attention_compiled( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_mask: BlockMask, +) -> torch.Tensor: + return _flex_attention_compiled(q, k, v, block_mask=block_mask) @dataclass @@ -45,8 +58,6 @@ class ModelArgs: depth_init: bool = True norm_type: str = "fused_rmsnorm" - attn_fn: Literal["sdpa", "flash"] = "sdpa" - def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: """ @@ -138,6 +149,48 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: ) +def seqlens_to_docs_tensor(seqlens: list[torch.Tensor]) -> torch.Tensor: + """Converts list of sequence lengths to document indices tensor. + Example: + seqlens = [tensor([2,2,1]), tensor([2,2,1])] # List of 2 tensors + docs = [[0,0,1,1,2], [0,0,1,1,2]] # Each doc_id repeated per its length + """ + return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens]) + + +def create_block_mask_from_seqlens(seqlens: list[torch.Tensor]) -> BlockMask: + """Creates a block mask from a list of sequence lengths. + + Example: + seqlens = [tensor([2,2,1]))] # List of 2 tensors + docs = [[0,0,1,1,2]] # Each doc_id repeated per its length + + mask = [[1 1 0 0 0] # First token of doc 0 can see itself and second token of doc 0 + [1 1 0 0 0] # Second token of doc 0 can see both tokens of doc 0 + [0 0 1 1 0] # First token of doc 1 can see itself and second token of doc 1 + [0 0 1 1 0] # Second token of doc 1 can see both tokens of doc 1 + [0 0 0 0 1]] # Token of doc 2 can only see itself + """ + docs = seqlens_to_docs_tensor(seqlens).to("cuda") + batch_size, max_seq_len = docs.shape + + def document_causal_mask(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + document_mask = docs[b, q_idx] == docs[b, kv_idx] + return causal_mask & document_mask + + return create_block_mask( + document_causal_mask, + batch_size, + None, + max_seq_len, + max_seq_len, + device="cuda", + _compile=True, + BLOCK_SIZE=max_seq_len if max_seq_len < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE, + ) + + class Attention(nn.Module): """ Multi-head attention module. @@ -164,8 +217,6 @@ def __init__(self, model_args: ModelArgs): self.n_rep = self.n_heads // self.n_kv_heads self.head_dim = model_args.dim // model_args.n_heads - self.attn_fn = model_args.attn_fn - self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False) self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) @@ -180,7 +231,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, - seqlens: torch.Tensor | None = None, + block_mask: BlockMask | None = None, ): """ Forward pass of the attention module. @@ -214,7 +265,7 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - output = self.self_attention(xq, xk, xv, seqlens) + output = self.self_attention(xq, xk, xv, block_mask) output = output.view(bs, seqlen, -1) return self.wo(output) @@ -224,55 +275,21 @@ def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor: output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim) return output - def _flash_attention(self, xq, xk, xv) -> torch.Tensor: - q = rearrange(xq, "b n t h -> b t n h") - k = rearrange(xk, "b n t h -> b t n h") - v = rearrange(xv, "b n t h -> b t n h") - # q/k/b is [b, nh, t, hs] but fa2 expected [b , t, nh, hs] - return flash_attn_func(q, k, v, causal=True) - - def _fa_attention_with_seqlens(self, xq, xk, xv, seqlens) -> torch.Tensor: - b = xq.shape[0] - cu_seqlens = ( - torch.concat([torch.tensor([0]).to(xq.device), seqlens.cumsum(0)], dim=0).to(torch.int32).to(xq.device) - ) - max_seqlen = seqlens.max() - - q = rearrange(xq, "b n t h -> (b t) n h") - k = rearrange(xk, "b n t h -> (b t) n h") - v = rearrange(xv, "b n t h -> (b t) n h") - # q/k/v is [b, nh, t, hs] but fa expected [b * t, nh, hs] - - y = flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - causal=True, - ) - - y = rearrange(y, "(b t) n h -> b t n h", b=b) - return y + def _flex_attention_with_seqlens(self, xq, xk, xv, block_mask: BlockMask) -> torch.Tensor: + output = flex_attention_compiled(xq, xk, xv, block_mask=block_mask) + output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim) + return output + # output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + # output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim) + # return output def self_attention( - self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, seqlens: torch.Tensor | None = None + self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, block_mask: BlockMask | None = None ) -> torch.Tensor: - if self.attn_fn == "sdpa": - if seqlens is not None: - raise NotImplementedError("SDPA with seqlens is not implemented.") - return self._sdpa_attention(xq, xk, xv) - elif self.attn_fn == "flash": - if not flash_attn_available: - raise RuntimeError("Flash attention is not available. Please install flash_attn.") - if seqlens is not None: - return self._fa_attention_with_seqlens(xq, xk, xv, seqlens) - else: - return self._flash_attention(xq, xk, xv) + if block_mask is not None: + return self._flex_attention_with_seqlens(xq, xk, xv, block_mask) else: - raise ValueError(f"Unknown attention function: {self.attn_fn}") + return self._sdpa_attention(xq, xk, xv) class FeedForward(nn.Module): @@ -365,7 +382,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, - seqlens: torch.Tensor | None = None, + block_mask: BlockMask | None = None, ): """ Perform a forward pass through the TransformerBlock. @@ -378,7 +395,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis, seqlens=seqlens) + h = x + self.attention(self.attention_norm(x), freqs_cis, block_mask=block_mask) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -475,14 +492,13 @@ def _precompute_freqs_cis(self) -> torch.Tensor: self.model_args.rope_theta, ) - def forward(self, tokens: torch.Tensor, seqlens: torch.Tensor | None = None): + def forward(self, tokens: torch.Tensor, block_mask: BlockMask | None = None): """ Perform a forward pass through the Transformer model. Args: tokens (torch.Tensor): Input token indices. - seqlens (torch.Tensor | None): Sequence lengths tensor for packing. - + block_mask (BlockMask | None): Block mask for attention. Returns: torch.Tensor: Output logits after applying the Transformer model. @@ -491,7 +507,7 @@ def forward(self, tokens: torch.Tensor, seqlens: torch.Tensor | None = None): h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis, seqlens=seqlens) + h = layer(h, self.freqs_cis, block_mask=block_mask) h = self.norm(h) if self.norm else h output = self.output(h).float() if self.output else h diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 14d02b99..7ab7cb8d 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -1,6 +1,7 @@ import os from typing import Literal import time +import warnings import psutil from pydantic import model_validator from multiprocessing.process import _children @@ -19,6 +20,7 @@ from zeroband.diloco import Diloco, DilocoConfig from zeroband.comms import ElasticDeviceMesh from zeroband.loss import cross_entropy_max_z_loss +from zeroband.models.llama.model import create_block_mask_from_seqlens from zeroband.utils import ( FakeTokenizer, @@ -75,12 +77,12 @@ class TrainConfig(BaseConfig): memory_profiler: MemoryProfilerConfig | None = None sequence_packing: bool = True - attn_fn: Literal["flash", "sdpa"] = "flash" + attn_fn: Literal["flash", "sdpa"] | None = None @model_validator(mode="after") def validate_attn_fn(self): - if self.attn_fn == "sdpa" and self.sequence_packing: - raise ValueError("SDPA does not support sequence packing") + if self.attn_fn is not None: + warnings.warn("attn_fn argument is deprecated") return self @@ -162,7 +164,6 @@ def train(config: Config): config.type_model, vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE, seq_length=config.data.seq_length, - attn_fn=config.train.attn_fn, ) model = model.to(world_info.local_rank) @@ -360,14 +361,12 @@ def train(config: Config): input_ids = batch["input_ids"].to("cuda") labels = batch["labels"].to("cuda") if config.train.sequence_packing: - seqlens = batch["seqlens"].to("cuda") - # seqlens has a dynamic shape but fixed dimension, this allow to still torch compile - # https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html - torch._dynamo.mark_dynamic(seqlens, 0) + seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]] + block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None else: - seqlens = None + block_mask = None - logits = model(tokens=input_ids, seqlens=seqlens).contiguous() + logits = model(tokens=input_ids, block_mask=block_mask).contiguous() flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab") flatten_labels = rearrange(labels, "b seq -> (b seq)") diff --git a/tests/test_model.py b/tests/test_model.py index debdc842..7853cb22 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -2,7 +2,7 @@ import pytest import torch from zeroband.models.llama import Transformer, llama2_configs -from zeroband.models.llama.model import Attention, ModelArgs +from zeroband.models.llama.model import Attention, ModelArgs, create_block_mask_from_seqlens VOCAB_SIZE = 1024 @@ -26,11 +26,9 @@ def llama_config() -> ModelArgs: return config -@pytest.mark.parametrize("attn_fn", ["flash", "sdpa"]) -def test_llama(llama_config: ModelArgs, attn_fn): +def test_llama(llama_config: ModelArgs): seq_len = 512 bs = 8 - llama_config.attn_fn = attn_fn model = Transformer(llama_config).to("cuda") input_ = torch.randint(0, llama_config.vocab_size, (bs, seq_len)).to("cuda") with torch.autocast(device_type="cuda", dtype=torch.bfloat16): @@ -50,21 +48,38 @@ def test_attn(llama_config: ModelArgs): freqs_cis = get_freqs_cis(llama_config) input_ = torch.rand(bs, seq_len, llama_config.dim).to("cuda") + seqlens = [torch.Tensor([seq_len]).int().to("cuda") for _ in range(bs)] + block_mask = create_block_mask_from_seqlens(seqlens) attn = Attention(llama_config).to("cuda") - attn.attn_fn = "sdpa" with torch.autocast(device_type="cuda", dtype=torch.bfloat16): output_sdpa = attn(input_, freqs_cis) - attn.attn_fn = "flash" with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - output_fa = attn(input_, freqs_cis) + output_flex = attn(input_, freqs_cis, block_mask=block_mask) rtol = ERROR_RTOL[torch.bfloat16] atol = ERROR_ATOL[torch.bfloat16] - assert output_sdpa.shape == output_fa.shape - torch.testing.assert_close(output_sdpa, output_fa, rtol=rtol, atol=atol) + assert output_sdpa.shape == output_flex.shape + torch.testing.assert_close(output_sdpa, output_flex, rtol=rtol, atol=atol) + + +def test_packing_simple(llama_config: ModelArgs): + seq_len = 512 + bs = 8 + + freqs_cis = get_freqs_cis(llama_config) + input_ = torch.rand(bs, seq_len, llama_config.dim).to("cuda") + seqlens = [torch.Tensor([seq_len // 4] * 4).int().to("cuda") for _ in range(bs)] + block_mask = create_block_mask_from_seqlens(seqlens) + + attn = Attention(llama_config).to("cuda") + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + output = attn(input_, freqs_cis, block_mask=block_mask) + + assert output.shape == (bs, seq_len, llama_config.dim) def test_sequence_packing_two_time_same_sequence(llama_config: ModelArgs): @@ -73,22 +88,21 @@ def test_sequence_packing_two_time_same_sequence(llama_config: ModelArgs): We then pass the packed sequence to the attention layer and check that the output for each sequence is the same. """ - llama_config.attn_fn = "flash" model = Attention(llama_config).to("cuda") emb = torch.nn.Embedding(10, llama_config.dim).to("cuda") seq = [2, 1, 4, 8] input_stuff_raw = torch.Tensor([seq + seq]).long().to("cuda") - seqlens = [len(seq), len(seq)] - seqlens = torch.Tensor(seqlens).int().to("cuda") + seqlens = [torch.Tensor([len(seq), len(seq)]).int().to("cuda")] + block_mask = create_block_mask_from_seqlens(seqlens) input_stuff = emb(input_stuff_raw) freqs_cis = get_freqs_cis(llama_config) with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - output = model(input_stuff, freqs_cis, seqlens=seqlens) + output = model(input_stuff, freqs_cis, block_mask=block_mask) output_left = output[:, :4, :] output_right = output[:, 4:, :] @@ -106,7 +120,6 @@ def test_sequence_packing_vs_normal(llama_config: ModelArgs): take two sequences and compare the outout of attention on individual sequences vs the output of attention on the packed sequence """ - llama_config.attn_fn = "flash" model = Attention(llama_config).to("cuda") emb = torch.nn.Embedding(10, llama_config.dim).to("cuda") @@ -116,13 +129,13 @@ def test_sequence_packing_vs_normal(llama_config: ModelArgs): seq_2 = [3, 7, 5, 6] input_packed_raw = torch.Tensor([seq_1 + seq_2]).long().to("cuda") - seqlens = [len(seq_1), len(seq_2)] - seqlens = torch.Tensor(seqlens).int().to("cuda") + seqlens = [torch.Tensor([len(seq_1), len(seq_2)]).int().to("cuda")] + block_mask = create_block_mask_from_seqlens(seqlens) input_packed = emb(input_packed_raw) with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - output = model(input_packed, freqs_cis, seqlens=seqlens) + output = model(input_packed, freqs_cis, block_mask=block_mask) output_packed_1 = output[:, :4, :] output_packed_2 = output[:, 4:, :] @@ -153,7 +166,6 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs): take two sequences and compare the outout of attention on individual sequences vs the output of attention on the packed sequence """ - llama_config.attn_fn = "flash" model = Attention(llama_config).to("cuda") freqs_cis = get_freqs_cis(llama_config) @@ -163,17 +175,19 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs): for _ in range(10): seq_len_cutoff = random.randint(1, MAX_SEQ_LEN) - input_1 = torch.rand(1, seq_len_cutoff, llama_config.dim).to("cuda") - input_2 = torch.rand(1, MAX_SEQ_LEN - seq_len_cutoff, llama_config.dim).to("cuda") + seq1 = seq_len_cutoff + seq2 = MAX_SEQ_LEN - seq_len_cutoff + input_1 = torch.rand(1, seq1, llama_config.dim).to("cuda") + input_2 = torch.rand(1, seq2, llama_config.dim).to("cuda") - seqlens = [seq_len_cutoff, MAX_SEQ_LEN - seq_len_cutoff] - seqlens = torch.Tensor(seqlens).int().to("cuda") + seqlens = [torch.Tensor([seq1, seq2]).int().to("cuda")] + block_mask = create_block_mask_from_seqlens(seqlens) packed_input = torch.cat([input_1, input_2], dim=1) # packed output with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - output = model(packed_input, freqs_cis, seqlens=seqlens) + output = model(packed_input, freqs_cis, block_mask=block_mask) output_packed_1 = output[:, :seq_len_cutoff, :] output_packed_2 = output[:, seq_len_cutoff:, :] @@ -195,7 +209,6 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs): def test_end_to_end_packing(llama_config: ModelArgs): - llama_config.attn_fn = "flash" model = Transformer(llama_config).to("cuda") BS = 8 @@ -203,11 +216,10 @@ def test_end_to_end_packing(llama_config: ModelArgs): input_ = torch.randint(1, llama_config.vocab_size, (BS, SEQ_LEN)).to("cuda") - seqlens = [SEQ_LEN // 4, SEQ_LEN // 4, SEQ_LEN // 2] - seqlens = torch.Tensor(seqlens).int().to("cuda") - + seqlens = [torch.Tensor([SEQ_LEN // 4, SEQ_LEN // 4, SEQ_LEN // 2]).int().to("cuda") for _ in range(BS)] + block_mask = create_block_mask_from_seqlens(seqlens) with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - output = model(input_, seqlens=seqlens) + output = model(input_, block_mask=block_mask) assert output.shape == (BS, SEQ_LEN, llama_config.vocab_size) diff --git a/uv.lock b/uv.lock index 93ab0f5b..0b3b3f34 100644 --- a/uv.lock +++ b/uv.lock @@ -963,38 +963,42 @@ wheels = [ [[package]] name = "nvidia-cublas-cu12" -version = "12.1.3.1" +version = "12.4.5.8" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774 }, - { url = "https://files.pythonhosted.org/packages/c5/ef/32a375b74bea706c93deea5613552f7c9104f961b21df423f5887eca713b/nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906", size = 439918445 }, + { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771 }, + { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 }, + { url = "https://files.pythonhosted.org/packages/e2/2a/4f27ca96232e8b5269074a72e03b4e0d43aa68c9b965058b1684d07c6ff8/nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc", size = 396895858 }, ] [[package]] name = "nvidia-cuda-cupti-cu12" -version = "12.1.105" +version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015 }, - { url = "https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 }, + { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556 }, + { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 }, + { url = "https://files.pythonhosted.org/packages/f3/79/8cf313ec17c58ccebc965568e5bcb265cdab0a1df99c4e674bb7a3b99bfe/nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922", size = 9938035 }, ] [[package]] name = "nvidia-cuda-nvrtc-cu12" -version = "12.1.105" +version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734 }, - { url = "https://files.pythonhosted.org/packages/ad/1d/f76987c4f454eb86e0b9a0e4f57c3bf1ac1d13ad13cd1a4da4eb0e0c0ce9/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed", size = 19331863 }, + { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372 }, + { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 }, + { url = "https://files.pythonhosted.org/packages/7c/30/8c844bfb770f045bcd8b2c83455c5afb45983e1a8abf0c4e5297b481b6a5/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec", size = 19751955 }, ] [[package]] name = "nvidia-cuda-runtime-cu12" -version = "12.1.105" +version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 }, + { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177 }, + { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 }, + { url = "https://files.pythonhosted.org/packages/a8/8b/450e93fab75d85a69b50ea2d5fdd4ff44541e0138db16f9cd90123ef4de4/nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e", size = 878808 }, ] [[package]] @@ -1011,25 +1015,30 @@ wheels = [ [[package]] name = "nvidia-cufft-cu12" -version = "11.0.2.54" +version = "11.2.1.3" source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] wheels = [ - { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 }, - { url = "https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 }, + { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 }, + { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, + { url = "https://files.pythonhosted.org/packages/f6/ee/3f3f8e9874f0be5bbba8fb4b62b3de050156d159f8b6edc42d6f1074113b/nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b", size = 210576476 }, ] [[package]] name = "nvidia-curand-cu12" -version = "10.3.2.106" +version = "10.3.5.147" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784 }, - { url = "https://files.pythonhosted.org/packages/5c/97/4c9c7c79efcdf5b70374241d48cf03b94ef6707fd18ea0c0f53684931d0b/nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a", size = 55995813 }, + { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811 }, + { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 }, + { url = "https://files.pythonhosted.org/packages/1c/22/2573503d0d4e45673c263a313f79410e110eb562636b0617856fdb2ff5f6/nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771", size = 55799918 }, ] [[package]] name = "nvidia-cusolver-cu12" -version = "11.4.5.107" +version = "11.6.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-cublas-cu12" }, @@ -1037,48 +1046,50 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, - { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 }, + { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 }, + { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, + { url = "https://files.pythonhosted.org/packages/f2/be/d435b7b020e854d5d5a682eb5de4328fd62f6182507406f2818280e206e2/nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c", size = 125224015 }, ] [[package]] name = "nvidia-cusparse-cu12" -version = "12.1.0.106" +version = "12.3.1.170" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, - { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 }, + { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 }, + { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, + { url = "https://files.pythonhosted.org/packages/a2/e0/3155ca539760a8118ec94cc279b34293309bcd14011fc724f87f31988843/nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f", size = 204684315 }, ] [[package]] name = "nvidia-nccl-cu12" -version = "2.20.5" +version = "2.21.5" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", size = 176238458 }, - { url = "https://files.pythonhosted.org/packages/4b/2a/0a131f572aa09f741c30ccd45a8e56316e8be8dfc7bc19bf0ab7cfef7b19/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56", size = 176249402 }, + { url = "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", size = 188654414 }, ] [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.6.68" +version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/8c/69c9e39cd6bfa813852a94e9bd3c075045e2707d163e9dc2326c82d2c330/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b3fd0779845f68b92063ab1393abab1ed0a23412fc520df79a8190d098b5cd6b", size = 19253287 }, - { url = "https://files.pythonhosted.org/packages/a8/48/a9775d377cb95585fb188b469387f58ba6738e268de22eae2ad4cedb2c41/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl", hash = "sha256:125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab", size = 19725597 }, - { url = "https://files.pythonhosted.org/packages/00/d5/02af3b39427ed71e8c40b6912271499ec186a72405bcb7e4ca26ff70678c/nvidia_nvjitlink_cu12-12.6.68-py3-none-win_amd64.whl", hash = "sha256:a55744c98d70317c5e23db14866a8cc2b733f7324509e941fc96276f9f37801d", size = 161730369 }, + { url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510 }, + { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 }, + { url = "https://files.pythonhosted.org/packages/81/19/0babc919031bee42620257b9a911c528f05fb2688520dcd9ca59159ffea8/nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1", size = 95336325 }, ] [[package]] name = "nvidia-nvtx-cu12" -version = "12.1.105" +version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138 }, - { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, + { url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417 }, + { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 }, + { url = "https://files.pythonhosted.org/packages/54/1b/f77674fbb73af98843be25803bbd3b9a4f0a96c75b8d33a2854a5c7d2d77/nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485", size = 66307 }, ] [[package]] @@ -1761,14 +1772,14 @@ wheels = [ [[package]] name = "sympy" -version = "1.13.2" +version = "1.13.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "mpmath" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/15/4a041424c7187f41cce678f5a02189b244e9aac61a18b45cd415a3a470f3/sympy-1.13.2.tar.gz", hash = "sha256:401449d84d07be9d0c7a46a64bd54fe097667d5e7181bfe67ec777be9e01cb13", size = 7532926 } +sdist = { url = "https://files.pythonhosted.org/packages/ca/99/5a5b6f19ff9f083671ddf7b9632028436167cd3d33e11015754e41b249a4/sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f", size = 7533040 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/f9/6845bf8fca0eaf847da21c5d5bc6cd92797364662824a11d3f836423a1a5/sympy-1.13.2-py3-none-any.whl", hash = "sha256:c51d75517712f1aed280d4ce58506a4a88d635d6b5dd48b39102a7ae1f3fcfe9", size = 6189289 }, + { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177 }, ] [[package]] @@ -1848,7 +1859,7 @@ wheels = [ [[package]] name = "torch" -version = "2.4.1" +version = "2.5.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -1865,24 +1876,27 @@ dependencies = [ { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" }, { name = "typing-extensions" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/41/05/d540049b1832d1062510efc6829634b7fbef5394c757d8312414fb65a3cb/torch-2.4.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:362f82e23a4cd46341daabb76fba08f04cd646df9bfaf5da50af97cb60ca4971", size = 797072810 }, - { url = "https://files.pythonhosted.org/packages/a0/12/2162df9c47386ae7cedbc938f9703fee4792d93504fab8608d541e71ece3/torch-2.4.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:e8ac1985c3ff0f60d85b991954cfc2cc25f79c84545aead422763148ed2759e3", size = 89699259 }, - { url = "https://files.pythonhosted.org/packages/5d/4c/b2a59ff0e265f5ee154f0d81e948b1518b94f545357731e1a3245ee5d45b/torch-2.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:91e326e2ccfb1496e3bee58f70ef605aeb27bd26be07ba64f37dcaac3d070ada", size = 199433813 }, - { url = "https://files.pythonhosted.org/packages/dc/fb/1333ba666bbd53846638dd75a7a1d4eaf964aff1c482fc046e2311a1b499/torch-2.4.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:d36a8ef100f5bff3e9c3cea934b9e0d7ea277cb8210c7152d34a9a6c5830eadd", size = 62139309 }, - { url = "https://files.pythonhosted.org/packages/ea/ea/4ab009e953bca6ff35ad75b8ab58c0923308636c182c145dc63084f7d136/torch-2.4.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:0b5f88afdfa05a335d80351e3cea57d38e578c8689f751d35e0ff36bce872113", size = 797111232 }, - { url = "https://files.pythonhosted.org/packages/8f/a1/b31f94b4631c1731261db9fdc9a749ef58facc3b76094a6fe974f611f239/torch-2.4.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:ef503165f2341942bfdf2bd520152f19540d0c0e34961232f134dc59ad435be8", size = 89719574 }, - { url = "https://files.pythonhosted.org/packages/5a/6a/775b93d6888c31f1f1fc457e4f5cc89f0984412d5dcdef792b8f2aa6e812/torch-2.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:092e7c2280c860eff762ac08c4bdcd53d701677851670695e0c22d6d345b269c", size = 199436128 }, - { url = "https://files.pythonhosted.org/packages/1f/34/c93873c37f93154d982172755f7e504fdbae6c760499303a3111ce6ce327/torch-2.4.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:ddddbd8b066e743934a4200b3d54267a46db02106876d21cf31f7da7a96f98ea", size = 62145176 }, - { url = "https://files.pythonhosted.org/packages/cc/df/5204a13a7a973c23c7ade615bafb1a3112b5d0ec258d8390f078fa4ab0f7/torch-2.4.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:fdc4fe11db3eb93c1115d3e973a27ac7c1a8318af8934ffa36b0370efe28e042", size = 797019590 }, - { url = "https://files.pythonhosted.org/packages/4f/16/d23a689e5ef8001ed2ace1a3a59f2fda842889b0c3f3877799089925282a/torch-2.4.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:18835374f599207a9e82c262153c20ddf42ea49bc76b6eadad8e5f49729f6e4d", size = 89613802 }, - { url = "https://files.pythonhosted.org/packages/a8/e0/ca8354dfb8d834a76da51b06e8248b70fc182bc163540507919124974bdf/torch-2.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:ebea70ff30544fc021d441ce6b219a88b67524f01170b1c538d7d3ebb5e7f56c", size = 199387694 }, - { url = "https://files.pythonhosted.org/packages/ac/30/8b6f77ea4ce84f015ee024b8dfef0dac289396254e8bfd493906d4cbb848/torch-2.4.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:72b484d5b6cec1a735bf3fa5a1c4883d01748698c5e9cfdbeb4ffab7c7987e0d", size = 62123443 }, + { url = "https://files.pythonhosted.org/packages/2a/ef/834af4a885b31a0b32fff2d80e1e40f771e1566ea8ded55347502440786a/torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:71328e1bbe39d213b8721678f9dcac30dfc452a46d586f1d514a6aa0a99d4744", size = 906446312 }, + { url = "https://files.pythonhosted.org/packages/69/f0/46e74e0d145f43fa506cb336eaefb2d240547e4ce1f496e442711093ab25/torch-2.5.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:34bfa1a852e5714cbfa17f27c49d8ce35e1b7af5608c4bc6e81392c352dbc601", size = 91919522 }, + { url = "https://files.pythonhosted.org/packages/a5/13/1eb674c8efbd04d71e4a157ceba991904f633e009a584dd65dccbafbb648/torch-2.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:32a037bd98a241df6c93e4c789b683335da76a2ac142c0973675b715102dc5fa", size = 203088048 }, + { url = "https://files.pythonhosted.org/packages/a9/9d/e0860474ee0ff8f6ef2c50ec8f71a250f38d78a9b9df9fd241ad3397a65b/torch-2.5.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:23d062bf70776a3d04dbe74db950db2a5245e1ba4f27208a87f0d743b0d06e86", size = 63877046 }, + { url = "https://files.pythonhosted.org/packages/d1/35/e8b2daf02ce933e4518e6f5682c72fd0ed66c15910ea1fb4168f442b71c4/torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:de5b7d6740c4b636ef4db92be922f0edc425b65ed78c5076c43c42d362a45457", size = 906474467 }, + { url = "https://files.pythonhosted.org/packages/40/04/bd91593a4ca178ece93ca55f27e2783aa524aaccbfda66831d59a054c31e/torch-2.5.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:340ce0432cad0d37f5a31be666896e16788f1adf8ad7be481196b503dad675b9", size = 91919450 }, + { url = "https://files.pythonhosted.org/packages/0d/4a/e51420d46cfc90562e85af2fee912237c662ab31140ab179e49bd69401d6/torch-2.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:603c52d2fe06433c18b747d25f5c333f9c1d58615620578c326d66f258686f9a", size = 203098237 }, + { url = "https://files.pythonhosted.org/packages/d0/db/5d9cbfbc7968d79c5c09a0bc0bc3735da079f2fd07cc10498a62b320a480/torch-2.5.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:31f8c39660962f9ae4eeec995e3049b5492eb7360dd4f07377658ef4d728fa4c", size = 63884466 }, + { url = "https://files.pythonhosted.org/packages/8b/5c/36c114d120bfe10f9323ed35061bc5878cc74f3f594003854b0ea298942f/torch-2.5.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ed231a4b3a5952177fafb661213d690a72caaad97d5824dd4fc17ab9e15cec03", size = 906389343 }, + { url = "https://files.pythonhosted.org/packages/6d/69/d8ada8b6e0a4257556d5b4ddeb4345ea8eeaaef3c98b60d1cca197c7ad8e/torch-2.5.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:3f4b7f10a247e0dcd7ea97dc2d3bfbfc90302ed36d7f3952b0008d0df264e697", size = 91811673 }, + { url = "https://files.pythonhosted.org/packages/5f/ba/607d013b55b9fd805db2a5c2662ec7551f1910b4eef39653eeaba182c5b2/torch-2.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:73e58e78f7d220917c5dbfad1a40e09df9929d3b95d25e57d9f8558f84c9a11c", size = 203046841 }, + { url = "https://files.pythonhosted.org/packages/57/6c/bf52ff061da33deb9f94f4121fde7ff3058812cb7d2036c97bc167793bd1/torch-2.5.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:8c712df61101964eb11910a846514011f0b6f5920c55dbf567bff8a34163d5b1", size = 63858109 }, + { url = "https://files.pythonhosted.org/packages/69/72/20cb30f3b39a9face296491a86adb6ff8f1a47a897e4d14667e6cf89d5c3/torch-2.5.1-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:9b61edf3b4f6e3b0e0adda8b3960266b9009d02b37555971f4d1c8f7a05afed7", size = 906393265 }, ] [[package]] @@ -1941,15 +1955,15 @@ wheels = [ [[package]] name = "triton" -version = "3.0.0" +version = "3.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock", marker = "python_full_version < '3.13'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 }, - { url = "https://files.pythonhosted.org/packages/33/3e/a2f59384587eff6aeb7d37b6780de7fedd2214935e27520430ca9f5b7975/triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c", size = 209438883 }, - { url = "https://files.pythonhosted.org/packages/fe/7b/7757205dee3628f75e7991021d15cd1bd0c9b044ca9affe99b50879fc0e1/triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb", size = 209464695 }, + { url = "https://files.pythonhosted.org/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8", size = 209460013 }, + { url = "https://files.pythonhosted.org/packages/86/17/d9a5cf4fcf46291856d1e90762e36cbabd2a56c7265da0d1d9508c8e3943/triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f34f6e7885d1bf0eaaf7ba875a5f0ce6f3c13ba98f9503651c1e6dc6757ed5c", size = 209506424 }, + { url = "https://files.pythonhosted.org/packages/78/eb/65f5ba83c2a123f6498a3097746607e5b2f16add29e36765305e4ac7fdd8/triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc", size = 209551444 }, ] [[package]] @@ -2219,7 +2233,7 @@ requires-dist = [ { name = "requests", marker = "extra == 'all'", specifier = ">=2.32.3" }, { name = "setuptools" }, { name = "toposolve" }, - { name = "torch", specifier = "==2.4.1" }, + { name = "torch", specifier = "==2.5.1" }, { name = "torchdata", specifier = ">=0.8.0" }, { name = "transformers", specifier = ">=4.44.2" }, { name = "wandb", marker = "extra == 'all'" },