diff --git a/README.md b/README.md index 6402c47d0..459ba84d1 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ NVIDIA RecSys Examples is a collection of optimized recommender models and compo The project includes: - Examples for large-scale HSTU ranking and retrieval models through [TorchRec](https://github.com/pytorch/torchrec) and [Megatron-Core](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core) integration +- Examples for semantic-id based retrieval model through [TorchRec](https://github.com/pytorch/torchrec) and [Megatron-Core](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core) integration - HSTU (Hierarchical Sequential Transduction Unit) attention operator support - Dynamic Embeddings with GPU acceleration @@ -47,6 +48,7 @@ For more detailed release notes, please refer our [releases](https://github.com/ # Get Started The examples we supported: - [HSTU recommender examples](./examples/hstu/README.md) +- [SID based generative recommender examples](./examples/sid_gr/README.md) # Contribution Guidelines Please see our [contributing guidelines](./CONTRIBUTING.md) for details on how to contribute to this project. diff --git a/docker/Dockerfile b/docker/Dockerfile index 285bc63de..9bd01a18e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -79,5 +79,5 @@ RUN cd /workspace/recsys-examples/corelib/hstu && \ cd hopper && \ HSTU_DISABLE_ARBITRARY=TRUE HSTU_DISABLE_SM8x=TRUE HSTU_DISABLE_LOCAL=TRUE HSTU_DISABLE_RAB=TRUE HSTU_DISABLE_DELTA_Q=FALSE HSTU_DISABLE_DRAB=TRUE pip install . -RUN cd /workspace/recsys-examples/examples/hstu && \ +RUN cd /workspace/recsys-examples/examples/commons && \ TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0" python3 setup.py install diff --git a/examples/hstu/distributed/__init__.py b/examples/commons/distributed/__init__.py similarity index 100% rename from examples/hstu/distributed/__init__.py rename to examples/commons/distributed/__init__.py diff --git a/examples/hstu/distributed/dmp_to_tp.py b/examples/commons/distributed/dmp_to_tp.py similarity index 87% rename from examples/hstu/distributed/dmp_to_tp.py rename to examples/commons/distributed/dmp_to_tp.py index 6a4a144b2..26bbc744b 100644 --- a/examples/hstu/distributed/dmp_to_tp.py +++ b/examples/commons/distributed/dmp_to_tp.py @@ -1,16 +1,15 @@ -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union import torch import torch.distributed as dist -from dataset.utils import RankingBatch, RetrievalBatch -from megatron.core import parallel_state -from ops.collective_ops import ( +from commons.ops.collective_ops import ( gather_along_first_dim, gatherv_along_first_dim, jagged_tensor_allgather, keyed_jagged_tensor_allgather, ) -from ops.grad_scaling import grad_scaling +from commons.ops.grad_scaling import grad_scaling +from megatron.core import parallel_state from torchrec.sparse.jagged_tensor import JaggedTensor @@ -35,9 +34,7 @@ def jt_dict_grad_scaling_and_allgather( # The features is a kjt, input to embedding module. -def dmp_batch_to_tp( - batch: Union[RetrievalBatch, RankingBatch], exclude_features: bool = True -) -> Union[RetrievalBatch, RankingBatch]: +def dmp_batch_to_tp(batch: Any, exclude_features: bool = True) -> Any: tp_pg = parallel_state.get_tensor_model_parallel_group() tp_size = dist.get_world_size(group=tp_pg) batch_cls = type(batch) diff --git a/examples/hstu/distributed/finalize_model_grads.py b/examples/commons/distributed/finalize_model_grads.py similarity index 100% rename from examples/hstu/distributed/finalize_model_grads.py rename to examples/commons/distributed/finalize_model_grads.py diff --git a/examples/hstu/distributed/sharding.py b/examples/commons/distributed/sharding.py similarity index 98% rename from examples/hstu/distributed/sharding.py rename to examples/commons/distributed/sharding.py index b564988c6..525db239a 100644 --- a/examples/hstu/distributed/sharding.py +++ b/examples/commons/distributed/sharding.py @@ -19,10 +19,11 @@ import torch import torch.distributed as dist import torchrec -from configs.task_config import OptimizerParam # import our own finalize model grads -from distributed.finalize_model_grads import finalize_model_grads +from commons.distributed.finalize_model_grads import finalize_model_grads +from commons.modules.embedding import DataParallelEmbeddingCollection +from commons.optimizer import OptimizerParam from dynamicemb import DynamicEmbTableOptions from dynamicemb.get_planner import get_planner from dynamicemb.planner import ( @@ -40,7 +41,6 @@ from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer from megatron.core.transformer import TransformerConfig from megatron.core.transformer.module import Float16Module -from modules.embedding import DataParallelEmbeddingCollection from torch import distributed as dist from torch.distributed.optim import ( _apply_optimizer_in_backward as apply_optimizer_in_backward, diff --git a/examples/commons/modules/__init__.py b/examples/commons/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/hstu/modules/embedding.py b/examples/commons/modules/embedding.py similarity index 90% rename from examples/hstu/modules/embedding.py rename to examples/commons/modules/embedding.py index dee397bf6..3b5cd9a81 100644 --- a/examples/hstu/modules/embedding.py +++ b/examples/commons/modules/embedding.py @@ -14,6 +14,7 @@ # limitations under the License. import copy import os +from dataclasses import dataclass # pyre-strict from typing import Any, Dict, Iterator, List, Optional, Tuple @@ -23,7 +24,6 @@ import torch.fx import torch.nn as nn from commons.utils.nvtx_op import output_nvtx_hook, register_setter_and_getter_for_nvtx -from configs.task_config import ShardedEmbeddingConfig from dynamicemb.planner import ( DynamicEmbeddingShardingPlanner as DynamicEmbeddingShardingPlanner, ) @@ -52,6 +52,40 @@ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +@dataclass +class ShardedEmbeddingConfig: + """ + Configuration for sharded embeddings with sharding type. Inherits from BaseShardedEmbeddingConfig. + + Args: + config (EmbeddingConfig): The embedding configuration. + sharding_type (str): The type of sharding, ``'data_parallel'`` | ``'model_parallel'``. + """ + + """ + Base configuration for sharded embeddings. + + Args: + feature_names (List[str]): The name of the features in this embedding. + table_name (str): The name of the table. + vocab_size (int): The size of the vocabulary. + dim (int): The dimension size of the embeddings. + sharding_type (str): The type of sharding, ``'data_parallel'`` | ``'model_parallel'``. + """ + + feature_names: List[str] + table_name: str + vocab_size: int + dim: int + sharding_type: str + + def __post_init__(self): + assert self.sharding_type in [ + "data_parallel", + "model_parallel", + ], "sharding type should be data_parallel or model_parallel" + + def create_data_parallel_sharding_infos_by_sharding( module: EmbeddingCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], @@ -132,6 +166,7 @@ class DataParallelEmbeddingCollection(torch.nn.Module): """ Sharded implementation of `EmbeddingCollection`. This is part of the public API to allow for manual data dist pipelining. + We re-implement the DP embedding so that it can be wrapped by Megatron DDP. """ def __init__( @@ -354,14 +389,19 @@ def forward(self, kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]: Returns: `Dict[str, JaggedTensor ]`: The output embeddings. """ - mp_embeddings_awaitables = self._model_parallel_embedding_collection(kjt) + assert not ( + self._model_parallel_embedding_collection is None + and self._data_parallel_embedding_collection is None + ), "either model_parallel_embedding_collection or data_parallel_embedding_collection must be not None" + embeddings: Dict[str, JaggedTensor] = {} + if self._model_parallel_embedding_collection is not None: + mp_embeddings_awaitables = self._model_parallel_embedding_collection(kjt) + embeddings = {**embeddings, **(mp_embeddings_awaitables.wait())} if self._data_parallel_embedding_collection is not None: with torch.cuda.stream(self._side_stream): dp_embeddings = self._data_parallel_embedding_collection(kjt) torch.cuda.current_stream().wait_stream(self._side_stream) - embeddings = {**mp_embeddings_awaitables.wait(), **dp_embeddings} - else: - embeddings = mp_embeddings_awaitables.wait() + embeddings = {**embeddings, **dp_embeddings} return embeddings def export_local_embedding(self, table_name: str) -> Tuple[np.ndarray, np.ndarray]: @@ -381,7 +421,7 @@ def export_local_embedding(self, table_name: str) -> Tuple[np.ndarray, np.ndarra Example: >>> # assume we have 2 ranks >>> import torch - >>> from modules.embedding import ShardedEmbedding + >>> from commons.modules.embedding import ShardedEmbedding >>> from configs.task_config import ShardedEmbeddingConfig >>> from commons.utils.initialize as init >>> from commons.utils.logger import print_rank_0 diff --git a/examples/hstu/ops/collective_ops.py b/examples/commons/ops/collective_ops.py similarity index 99% rename from examples/hstu/ops/collective_ops.py rename to examples/commons/ops/collective_ops.py index ec5c50a52..0d5e1331b 100644 --- a/examples/hstu/ops/collective_ops.py +++ b/examples/commons/ops/collective_ops.py @@ -16,7 +16,7 @@ import torch import torch.distributed as dist -from ops.length_to_offsets import length_to_complete_offsets +from commons.ops.length_to_offsets import length_to_complete_offsets from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor diff --git a/examples/hstu/ops/cuda_ops/JaggedTensorOpFunction.py b/examples/commons/ops/cuda_ops/JaggedTensorOpFunction.py similarity index 99% rename from examples/hstu/ops/cuda_ops/JaggedTensorOpFunction.py rename to examples/commons/ops/cuda_ops/JaggedTensorOpFunction.py index d4ec5a26e..ce825a08b 100644 --- a/examples/hstu/ops/cuda_ops/JaggedTensorOpFunction.py +++ b/examples/commons/ops/cuda_ops/JaggedTensorOpFunction.py @@ -2,7 +2,7 @@ import hstu_cuda_ops import torch -from ops.length_to_offsets import length_to_complete_offsets +from commons.ops.length_to_offsets import length_to_complete_offsets class _JaggedTensorOpFunction(torch.autograd.Function): diff --git a/examples/hstu/ops/cuda_ops/csrc/jagged_tensor_op_cuda.cpp b/examples/commons/ops/cuda_ops/csrc/jagged_tensor_op_cuda.cpp similarity index 100% rename from examples/hstu/ops/cuda_ops/csrc/jagged_tensor_op_cuda.cpp rename to examples/commons/ops/cuda_ops/csrc/jagged_tensor_op_cuda.cpp diff --git a/examples/hstu/ops/cuda_ops/csrc/jagged_tensor_op_kernel.cu b/examples/commons/ops/cuda_ops/csrc/jagged_tensor_op_kernel.cu similarity index 100% rename from examples/hstu/ops/cuda_ops/csrc/jagged_tensor_op_kernel.cu rename to examples/commons/ops/cuda_ops/csrc/jagged_tensor_op_kernel.cu diff --git a/examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_cuda.cpp b/examples/commons/ops/cuda_ops/csrc/paged_kvcache_ops_cuda.cpp similarity index 100% rename from examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_cuda.cpp rename to examples/commons/ops/cuda_ops/csrc/paged_kvcache_ops_cuda.cpp diff --git a/examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_kernel.cu b/examples/commons/ops/cuda_ops/csrc/paged_kvcache_ops_kernel.cu similarity index 100% rename from examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_kernel.cu rename to examples/commons/ops/cuda_ops/csrc/paged_kvcache_ops_kernel.cu diff --git a/examples/hstu/ops/cuda_ops/csrc/vec_dtypes.cuh b/examples/commons/ops/cuda_ops/csrc/vec_dtypes.cuh similarity index 100% rename from examples/hstu/ops/cuda_ops/csrc/vec_dtypes.cuh rename to examples/commons/ops/cuda_ops/csrc/vec_dtypes.cuh diff --git a/examples/hstu/ops/grad_scaling.py b/examples/commons/ops/grad_scaling.py similarity index 100% rename from examples/hstu/ops/grad_scaling.py rename to examples/commons/ops/grad_scaling.py diff --git a/examples/hstu/ops/length_to_offsets.py b/examples/commons/ops/length_to_offsets.py similarity index 100% rename from examples/hstu/ops/length_to_offsets.py rename to examples/commons/ops/length_to_offsets.py diff --git a/examples/hstu/ops/triton_ops/common.py b/examples/commons/ops/triton_ops/common.py similarity index 83% rename from examples/hstu/ops/triton_ops/common.py rename to examples/commons/ops/triton_ops/common.py index 9940d77ab..2a27f48da 100644 --- a/examples/hstu/ops/triton_ops/common.py +++ b/examples/commons/ops/triton_ops/common.py @@ -33,7 +33,7 @@ import dataclasses from dataclasses import dataclass from enum import Enum, unique -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple, Union import torch @@ -129,55 +129,6 @@ class HammerKernel(Enum): # type: ignore[no-redef] TRITON_CC = "TRITON_CC" -class GRModuleBase(torch.nn.Module): - _is_inference: bool - _use_triton_cc: bool - _custom_kernel: bool - _hammer_kernel: Optional[HammerKernel] = None - - def __init__( - self, - is_inference: bool, - use_triton_cc: bool = True, - custom_kernel: bool = True, - hammer_kernel: Optional[HammerKernel] = None, - ) -> None: - super().__init__() - self._is_inference = is_inference - self._use_triton_cc = use_triton_cc - self._custom_kernel = custom_kernel - self._hammer_kernel = hammer_kernel - - def hammer_kernel(self) -> HammerKernel: - kernel = self._hammer_kernel - if kernel is not None: - return kernel - if self._custom_kernel: - if self._is_inference and self._use_triton_cc: - return HammerKernel.TRITON_CC - else: - return HammerKernel.TRITON - else: - return HammerKernel.PYTORCH - - # pyre-ignore[2] - def recursive_setattr(self, name: str, value: Any) -> None: - for _, module in self.named_modules(): - if hasattr(module, name): - setattr(module, name, value) - - @property - def predict_mode(self) -> bool: - return self._is_inference - - @property - def eval_mode(self) -> bool: - return (not self._is_inference) and (not self.training) - - @property - def train_mode(self) -> bool: - return (not self._is_inference) and self.training - def generate_sparse_seq_len( size: int, diff --git a/examples/commons/ops/triton_ops/triton_jagged.py b/examples/commons/ops/triton_ops/triton_jagged.py new file mode 100644 index 000000000..026eb327b --- /dev/null +++ b/examples/commons/ops/triton_ops/triton_jagged.py @@ -0,0 +1,1345 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# type: ignore +#!/usr/bin/env python3 + + +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl +from commons.ops.triton_ops.common import ( + autotune_max_seq_len, + switch_to_contiguous_if_needed, + triton_autotune, +) + + +def _get_bmm_configs() -> List[triton.Config]: + configs = [] + for BLOCK_M in [64, 128]: + for BLOCK_N in [64, 128]: + for BLOCK_K in [32, 64]: + for num_stages in [2, 3]: + for num_warps in [4, 8]: + configs.append( + triton.Config( + { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "BLOCK_K": BLOCK_K, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_bmm_configs(), + key=["AUTOTUNE_MAX_SEQ_LEN", "N", "K"], +) +@triton.jit +def jagged_dense_bmm_broadcast_add_kernel( + seq_offsets, + Jagged, + Dense, + Bias, + Out, + AUTOTUNE_MAX_SEQ_LEN, + N, + K, + stride_jm, + stride_db, + stride_dk, + stride_dn, + stride_bias_b, + stride_om, + HAS_BIAS: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """ + Computing bmm Out = Jagged x Dense + Bias + M is the jagged dimension + Jagged has shape (sum_B(M_i), K), Dense has shape (B, K, N), Bias has shape (B, N), and Out has shape (sum_B(M_i), N) + """ + + off_n = tl.program_id(0) + off_m = tl.program_id(1) + off_b = tl.program_id(2) + + seq_start = tl.load(seq_offsets + off_b).to(tl.int64) + seq_end = tl.load(seq_offsets + off_b + 1) + seq_len = seq_end - seq_start + start_m = off_m * BLOCK_M + start_n = off_n * BLOCK_N + if start_m >= seq_len: + return + + Jagged += seq_start * stride_jm + Dense += off_b.to(tl.int64) * stride_db + Out += seq_start * stride_om + + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + jg_ptrs = Jagged + offs_m[:, None] * stride_jm + offs_k[None, :] + dn_ptrs = Dense + offs_k[:, None] * stride_dk + offs_n[None, :] * stride_dn + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, K, BLOCK_K): + jg = tl.load( + jg_ptrs, + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask=(offs_m[:, None] < seq_len) and ((k + offs_k)[None, :] < K), + other=0.0, + ) + dn = tl.load( + dn_ptrs, + mask=((k + offs_k)[:, None] < K) and (offs_n[None, :] < N), + other=0.0, + ) + accumulator += tl.dot(jg, dn, allow_tf32=ALLOW_TF32) + jg_ptrs += BLOCK_K + dn_ptrs += BLOCK_K * stride_dk + + if HAS_BIAS: + bias_ptrs = Bias + off_b * stride_bias_b + offs_n + bias = tl.load(bias_ptrs, mask=offs_n < N) + accumulator += bias[None, :].to(tl.float32) + + out = accumulator.to(Out.dtype.element_ty) + + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + out_ptrs = Out + offs_m[:, None] * stride_om + offs_n[None, :] + tl.store(out_ptrs, out, mask=(offs_m[:, None] < seq_len) & (offs_n[None, :] < N)) + + +@triton_autotune( + configs=_get_bmm_configs(), + key=["M", "N", "AUTOTUNE_MAX_SEQ_LEN"], +) +@triton.jit +def _jagged_jagged_bmm_reduce_sum( + seq_offsets, + JaggedA, + JaggedB, + Out, + ReduceOut, + M, + N, + AUTOTUNE_MAX_SEQ_LEN, + stride_ak, + stride_bk, + stride_ob, + stride_om, + stride_on, + stride_orb, + stride_orn, + REDUCE_JAGGEDB: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """ + Computing bmm Out = Jagged x Jagged + K is the jagged dimension + JaggedA has shape (sum_B(K_i), M), JaggedB has shape (sum_B(K_i), N), and Out has shape (B, M, N) + """ + + off_b = tl.program_id(0) + off_m = tl.program_id(1) + off_n = tl.program_id(2) + + seq_start = tl.load(seq_offsets + off_b).to(tl.int64) + seq_end = tl.load(seq_offsets + off_b + 1) + seq_len = seq_end - seq_start + + start_m = off_m * BLOCK_M + start_n = off_n * BLOCK_N + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + Out += off_b.to(tl.int64) * stride_ob + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + out_ptrs = Out + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + if REDUCE_JAGGEDB: + out_reduce_ptrs = ReduceOut + off_b * stride_orb + offs_n * stride_orn + acc_reduce = tl.zeros((BLOCK_N,), dtype=tl.float32) + if seq_len == 0: + out = accumulator.to(Out.dtype.element_ty) + tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N)) + if REDUCE_JAGGEDB: + if off_m == 0: + tl.store( + out_reduce_ptrs, # pyre-ignore [61] + acc_reduce.to(ReduceOut.dtype.element_ty), + mask=(offs_n < N), + ) + return + + JaggedA += seq_start * stride_ak + JaggedB += seq_start * stride_bk + offs_k = tl.arange(0, BLOCK_K) + jg_a_ptrs = JaggedA + offs_k[None, :] * stride_ak + offs_m[:, None] + jg_b_ptrs = JaggedB + offs_k[:, None] * stride_bk + offs_n[None, :] + + for k in range(0, seq_len, BLOCK_K): + jg_a = tl.load( + jg_a_ptrs, + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask=(offs_m[:, None] < M) and ((k + offs_k)[None, :] < seq_len), + other=0.0, + ) + jg_b = tl.load( + jg_b_ptrs, + mask=(offs_n[None, :] < N) and ((k + offs_k)[:, None] < seq_len), + other=0.0, + ) + + accumulator += tl.dot(jg_a, jg_b, allow_tf32=ALLOW_TF32) + if REDUCE_JAGGEDB: + if off_m == 0: + acc_reduce += tl.sum(jg_b.to(tl.float32), axis=0) + + jg_a_ptrs += BLOCK_K * stride_ak + jg_b_ptrs += BLOCK_K * stride_bk + + out = accumulator.to(Out.dtype.element_ty) + tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N)) + if REDUCE_JAGGEDB: + if off_m == 0: + tl.store( + out_reduce_ptrs, # pyre-ignore [61] + acc_reduce.to(ReduceOut.dtype.element_ty), + mask=(offs_n < N), + ) + + +class _JaggedDenseBmmFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + ): + jagged = switch_to_contiguous_if_needed(jagged) + L, D = jagged.shape + B, _, K = dense.shape + bmm_out = torch.empty((L, K), dtype=jagged.dtype, device=jagged.device) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(K, meta["BLOCK_N"]), + triton.cdiv(max_seq_len, meta["BLOCK_M"]), + B, + ) + + jagged_dense_bmm_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=jagged, + Dense=dense, + Bias=None, + Out=bmm_out, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + N=K, + K=D, + stride_jm=jagged.stride(0), + stride_db=dense.stride(0), + stride_dk=dense.stride(1), + stride_dn=dense.stride(2), + stride_bias_b=0, + stride_om=bmm_out.stride(0), + HAS_BIAS=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ) + + ctx.save_for_backward(seq_offsets, jagged, dense) + ctx.B = B + ctx.max_seq_len = max_seq_len + ctx.K = K + ctx.D = D + return bmm_out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_bmm_out: torch.Tensor + ) -> Tuple[None, None, torch.Tensor, torch.Tensor]: + seq_offsets, jagged, dense = ctx.saved_tensors + d_jagged = torch.empty_like(jagged) + d_dense = torch.empty_like(dense) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(ctx.D, meta["BLOCK_N"]), + triton.cdiv(ctx.max_seq_len, meta["BLOCK_M"]), + ctx.B, + ) + jagged_dense_bmm_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=d_bmm_out, + Dense=dense, + Bias=None, + Out=d_jagged, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + N=ctx.D, + K=ctx.K, + stride_jm=d_bmm_out.stride(0), + stride_db=dense.stride(0), + stride_dk=dense.stride(2), + stride_dn=dense.stride(1), + stride_bias_b=0, + stride_om=d_jagged.stride(0), + HAS_BIAS=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ) + + grid = lambda meta: ( # noqa E731 + ctx.B, + triton.cdiv(ctx.D, meta["BLOCK_M"]), + triton.cdiv(ctx.K, meta["BLOCK_N"]), + ) + _jagged_jagged_bmm_reduce_sum[grid]( + seq_offsets=seq_offsets, + JaggedA=jagged, + JaggedB=d_bmm_out, + Out=d_dense, + ReduceOut=None, + M=ctx.D, + N=ctx.K, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + stride_ak=jagged.stride(0), + stride_bk=d_bmm_out.stride(0), + stride_ob=d_dense.stride(0), + stride_om=d_dense.stride(1), + stride_on=d_dense.stride(2), + stride_orb=0, + stride_orn=0, + REDUCE_JAGGEDB=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ) + + return None, None, d_jagged, d_dense + + +def _get_jagged_dense_broadcast_add_configs() -> List[triton.Config]: + configs = [] + for BLOCK_N in [16, 32, 64]: + for num_stages in [1, 2]: + for num_warps in [2, 4, 8]: + configs.append( + triton.Config( + { + "BLOCK_N": BLOCK_N, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_jagged_dense_broadcast_add_configs(), + key=["AUTOTUNE_MAX_SEQ_LEN"], +) +@triton.jit +def jagged_dense_broadcast_add_kernel( + seq_offsets, + Jagged, + Dense, + Out, + AUTOTUNE_MAX_SEQ_LEN, + D, + stride_jn, + stride_db, + stride_on, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + """ + Computing Out = Jagged + Dense + JaggedA has shape (sum_B(N_i), D), Dense has shape (B, D), and Out has shape (sum_B(N_i), D) + """ + + off_b = tl.program_id(0) + off_n = tl.program_id(1) + seq_start = tl.load(seq_offsets + off_b) + seq_end = tl.load(seq_offsets + off_b + 1) + seq_len = seq_end - seq_start + start_n = off_n * BLOCK_N + if start_n >= seq_len: + return + Jagged += seq_start * stride_jn + Dense += off_b * stride_db + Out += seq_start * stride_on + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + jagged_ptrs = Jagged + offs_n[:, None] * stride_jn + offs_d[None, :] + dense_ptrs = Dense + offs_d + out_ptrs = Out + offs_n[:, None] * stride_jn + offs_d[None, :] + for d in range(0, D, BLOCK_D): + jg = tl.load( + jagged_ptrs, + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask=(offs_n[:, None] < seq_len) and (d + offs_d)[None, :] < D, + ) + dn = tl.load(dense_ptrs, mask=d + offs_d < D) + out = jg + dn[None, :] + tl.store( + out_ptrs, + out, + mask=(offs_n[:, None] < seq_len) and (d + offs_d)[None, :] < D, + ) + dense_ptrs += BLOCK_D + jagged_ptrs += BLOCK_D + out_ptrs += BLOCK_D + + +@triton.jit +def jagged_reduce_sum( + seq_offsets, + Jagged, + Out, + D, + stride_jn, + stride_ob, + BLOCK_D: tl.constexpr, +): + """ + Computing Out = Jagged + Dense + JaggedA has shape (sum_B(N_i), D), Dense has shape (B, D), and Out has shape (sum_B(N_i), D) + """ + off_b = tl.program_id(0) + off_d = tl.program_id(1) * BLOCK_D + seq_start = tl.load(seq_offsets + off_b) + seq_end = tl.load(seq_offsets + off_b + 1) + seq_len = seq_end - seq_start + Jagged += seq_start * stride_jn + Out += off_b * stride_ob + offs_d = off_d + tl.arange(0, BLOCK_D) + jagged_ptrs = Jagged + offs_d + out_ptrs = Out + offs_d + accumulator = tl.zeros((BLOCK_D,), dtype=tl.float32) + for _ in range(0, seq_len): + jg = tl.load( + jagged_ptrs, + mask=offs_d < D, + ) + accumulator += jg + jagged_ptrs += stride_jn + out = accumulator.to(Out.dtype.element_ty) + tl.store( + out_ptrs, + out, + mask=offs_d < D, + ) + + +class _JaggedDenseBroadcastAddFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + ): + jagged = switch_to_contiguous_if_needed(jagged) + dense = switch_to_contiguous_if_needed(dense) + L, D = jagged.shape + B, _ = dense.shape + out = torch.empty_like(jagged) + + grid = lambda meta: ( # noqa E731 + B, + triton.cdiv(max_seq_len, meta["BLOCK_N"]), + ) + BLOCK_D = triton.next_power_of_2(D) if D < 64 else 64 + jagged_dense_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=jagged, + Dense=dense, + Out=out, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + D=D, + stride_jn=jagged.stride(0), + stride_db=dense.stride(0), + stride_on=out.stride(0), + BLOCK_D=BLOCK_D, + ) + + ctx.save_for_backward(seq_offsets) + ctx.max_seq_len = max_seq_len + ctx.B = B + ctx.D = D + return out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, None, torch.Tensor, torch.Tensor]: + seq_offsets = ctx.saved_tensors[0] + d_dense = torch.empty((ctx.B, ctx.D), device=d_out.device, dtype=d_out.dtype) + BLOCK_D = triton.next_power_of_2(ctx.D) if ctx.D < 64 else 64 + jagged_reduce_sum[(ctx.B, triton.cdiv(ctx.D, BLOCK_D))]( + seq_offsets=seq_offsets, + Jagged=d_out, + Out=d_dense, + D=ctx.D, + stride_jn=d_out.stride(0), + stride_ob=d_dense.stride(0), + BLOCK_D=BLOCK_D, + ) + return None, None, d_out, d_dense + + +class _JaggedDenseBmmBroadcastAddFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + bias: torch.Tensor, + ): + jagged = switch_to_contiguous_if_needed(jagged) + bias = switch_to_contiguous_if_needed(bias) + L, K = jagged.shape + B, _, N = dense.shape + out = torch.empty((L, N), dtype=jagged.dtype, device=jagged.device) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(N, meta["BLOCK_N"]), + triton.cdiv(max_seq_len, meta["BLOCK_M"]), + B, + ) + + jagged_dense_bmm_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=jagged, + Dense=dense, + Bias=bias, + Out=out, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + N=N, + K=K, + stride_jm=jagged.stride(0), + stride_db=dense.stride(0), + stride_dk=dense.stride(1), + stride_dn=dense.stride(2), + stride_bias_b=bias.stride(0), + stride_om=out.stride(0), + HAS_BIAS=True, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ) + + ctx.save_for_backward(seq_offsets, jagged, dense) + ctx.B = B + ctx.max_seq_len = max_seq_len + ctx.K = K + ctx.N = N + return out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, None, torch.Tensor, torch.Tensor, torch.Tensor]: + seq_offsets, jagged, dense = ctx.saved_tensors + d_jagged = torch.empty_like(jagged) + d_dense = torch.empty_like(dense) + d_bias = torch.empty((ctx.B, ctx.N), device=d_out.device, dtype=d_out.dtype) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(ctx.K, meta["BLOCK_N"]), + triton.cdiv(ctx.max_seq_len, meta["BLOCK_M"]), + ctx.B, + ) + jagged_dense_bmm_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=d_out, + Dense=dense, + Bias=None, + Out=d_jagged, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + N=ctx.K, + K=ctx.N, + stride_jm=d_out.stride(0), + stride_db=dense.stride(0), + stride_dk=dense.stride(2), + stride_dn=dense.stride(1), + stride_bias_b=0, + stride_om=d_jagged.stride(0), + HAS_BIAS=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ) + + grid = lambda meta: ( # noqa E731 + ctx.B, + triton.cdiv(ctx.K, meta["BLOCK_M"]), + triton.cdiv(ctx.N, meta["BLOCK_N"]), + ) + _jagged_jagged_bmm_reduce_sum[grid]( + seq_offsets=seq_offsets, + JaggedA=jagged, + JaggedB=d_out, + Out=d_dense, + ReduceOut=d_bias, + M=ctx.K, + N=ctx.N, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + stride_ak=jagged.stride(0), + stride_bk=d_out.stride(0), + stride_ob=d_dense.stride(0), + stride_om=d_dense.stride(1), + stride_on=d_dense.stride(2), + stride_orb=d_bias.stride(0), + stride_orn=d_bias.stride(1), + REDUCE_JAGGEDB=True, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ) + + return None, None, d_jagged, d_dense, d_bias + + +@triton.jit +def concat_2D_jagged_w_prefix( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + n_prefix_from_B, # nonzero is not supported when IS_REPLACE=True + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + off_z = tl.program_id(1) + off_n = tl.program_id(0) + if IS_DENSE_A: + seq_start_a = off_z * DenseSize + seq_len_a = DenseSize + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + elif IS_DENSE_B: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = off_z * DenseSize + seq_len_b = DenseSize + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + + if IS_REPLACE: + seq_len = seq_len_a + else: + seq_len = seq_len_a + seq_len_b + if off_n >= seq_len: + return + + offs_d = tl.arange(0, BLOCK_D) + if IS_REPLACE: + out_seq_start = seq_start_a + off_n + out_seq_b_start = seq_len_a - seq_len_b + else: + out_seq_start = seq_start_a + seq_start_b + off_n + out_seq_b_start = seq_len_a + n_prefix_from_B + + out_ptrs = Out + out_seq_start.to(tl.int64) * stride_od + offs_d + if off_n < out_seq_b_start and off_n >= n_prefix_from_B: + off_a = off_n - n_prefix_from_B + if IS_DENSE_A: + in_ptrs = ( + ValuesA + + off_a.to(tl.int64) * stride_ad + + off_z.to(tl.int64) * stride_dense_batch + + offs_d + ) + else: + in_ptrs = ValuesA + (off_a + seq_start_a).to(tl.int64) * stride_ad + offs_d + else: + off_b = off_n - out_seq_b_start + n_prefix_from_B + if off_n < n_prefix_from_B: + off_b += out_seq_b_start - n_prefix_from_B + if IS_DENSE_B: + in_ptrs = ( + ValuesB + + off_b.to(tl.int64) * stride_bd + + off_z.to(tl.int64) * stride_dense_batch + + offs_d + ) + else: + in_ptrs = ValuesB + (off_b + seq_start_b).to(tl.int64) * stride_bd + offs_d + v = tl.load(in_ptrs, mask=offs_d < D) + tl.store(out_ptrs, v, mask=offs_d < D) + + +@triton.jit +def concat_2D_jagged( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + concat_2D_jagged_w_prefix( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + 0, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + IS_REPLACE, + ) + + +@triton.jit +def concat_2D_jagged_jagged_w_prefix( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + BLOCK_D: tl.constexpr, +): + concat_2D_jagged_w_prefix( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + 0, + Out, + D, + stride_ad, + stride_bd, + 0, + stride_od, + n_prefix_from_B, + IS_DENSE_A=False, + IS_DENSE_B=False, + BLOCK_D=BLOCK_D, + IS_REPLACE=False, + ) + + +@triton.jit +def split_2D_jagged_w_prefix( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + off_z = tl.program_id(1) + off_n = tl.program_id(0) + if IS_DENSE_A: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_start_a = off_z * DenseSize + seq_len_a = DenseSize + seq_len_b = seq_end_b - seq_start_b + elif IS_DENSE_B: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = off_z * DenseSize + seq_len_b = DenseSize + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + if IS_REPLACE: + seq_len = seq_len_a + else: + seq_len = seq_len_a + seq_len_b + if off_n >= seq_len: + return + + if IS_REPLACE: + seq_start = seq_start_a + out_seq_b_start = seq_len_a - seq_len_b + else: + seq_start = seq_start_a + seq_start_b + out_seq_b_start = seq_len_a + n_prefix_to_B + + offs_d = tl.arange(0, BLOCK_D) + in_ptrs = JaggedIn + (seq_start + off_n).to(tl.int64) * stride_id + offs_d + if off_n < out_seq_b_start and off_n >= n_prefix_to_B: + off_a = off_n - n_prefix_to_B + out_ptrs = OutA + (off_a + seq_start_a).to(tl.int64) * stride_ad + offs_d + else: + off_b = off_n - out_seq_b_start + n_prefix_to_B + if off_n < n_prefix_to_B: + off_b += out_seq_b_start - n_prefix_to_B + out_ptrs = OutB + (off_b + seq_start_b).to(tl.int64) * stride_bd + offs_d + v = tl.load(in_ptrs, mask=offs_d < D) + tl.store(out_ptrs, v, mask=offs_d < D) + + +@triton.jit +def split_2D_jagged( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + split_2D_jagged_w_prefix( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + 0, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + IS_REPLACE, + ) + + +@triton.jit +def split_2D_jagged_jagged_w_prefix( + JaggedIn, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + BLOCK_D: tl.constexpr, +): + split_2D_jagged_w_prefix( + JaggedIn, + 0, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A=False, + IS_DENSE_B=False, + BLOCK_D=BLOCK_D, + IS_REPLACE=False, + ) + + +class _Concat2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + is_replace: bool = False, + n_prefix_from_right: int = 0, + ): + values_a = switch_to_contiguous_if_needed(values_a) + values_b = switch_to_contiguous_if_needed(values_b) + is_dense_a = offsets_a is None + is_dense_b = offsets_b is None + dense_size: int = 0 + if is_dense_a: + assert offsets_b is not None + B, dense_size, D = values_a.shape + seq_len_a = dense_size * B + seq_len_b, _ = values_b.shape + device = values_b.device + dtype = values_b.dtype + stride_dense_batch = values_a.stride(0) + elif is_dense_b: + assert offsets_a is not None + B, dense_size, D = values_b.shape + seq_len_a, _ = values_a.shape + seq_len_b = dense_size * B + device = values_a.device + dtype = values_a.dtype + stride_dense_batch = values_b.stride(0) + else: + assert offsets_a is not None and offsets_b is not None + B = offsets_a.shape[0] - 1 + seq_len_a, D = values_a.shape + seq_len_b, _ = values_b.shape + device = values_a.device + dtype = values_a.dtype + stride_dense_batch = 0 + + BLOCK_D = triton.next_power_of_2(D) + if is_replace: + values_out = torch.empty_like(values_a) + else: + values_out = torch.empty( + (seq_len_a + seq_len_b, D), device=device, dtype=dtype + ) + if n_prefix_from_right == 0: + concat_2D_jagged[(max_seq_len, B)]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + DenseSize=dense_size, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_dense_batch=stride_dense_batch, + stride_od=values_out.stride(0), + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, + IS_REPLACE=is_replace, # pyre-ignore[6] + ) + else: + concat_2D_jagged_jagged_w_prefix[(max_seq_len, B)]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_od=values_out.stride(0), + n_prefix_from_B=n_prefix_from_right, + BLOCK_D=BLOCK_D, + ) + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.seq_len_a = seq_len_a + ctx.seq_len_b = seq_len_b + ctx.is_dense_a = is_dense_a + ctx.is_dense_b = is_dense_b + ctx.dense_size = dense_size + ctx.is_replace = is_replace + ctx.n_prefix_from_right = n_prefix_from_right + return values_out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, torch.Tensor, torch.Tensor, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + is_dense_a, is_dense_b, is_replace = ( + ctx.is_dense_a, + ctx.is_dense_b, + ctx.is_replace, + ) + dense_size = ctx.dense_size + if is_dense_a: + B = offsets_b.shape[0] - 1 + else: + B = offsets_a.shape[0] - 1 + _, D = d_out.shape + BLOCK_D = triton.next_power_of_2(D) + values_a = torch.zeros( + (ctx.seq_len_a, D), device=d_out.device, dtype=d_out.dtype + ) + values_b = torch.empty( + (ctx.seq_len_b, D), device=d_out.device, dtype=d_out.dtype + ) + if ctx.n_prefix_from_right == 0: + split_2D_jagged[(ctx.max_seq_len, B)]( + JaggedIn=d_out, + DenseSize=dense_size, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=values_a, + OutB=values_b, + D=D, + stride_id=d_out.stride(0), + stride_ad=values_a.stride(0), + stride_bd=values_b.stride(0), + BLOCK_D=BLOCK_D, + IS_DENSE_A=is_dense_a, + IS_DENSE_B=is_dense_b, + IS_REPLACE=is_replace, + ) + else: + split_2D_jagged_jagged_w_prefix[(ctx.max_seq_len, B)]( + JaggedIn=d_out, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=values_a, + OutB=values_b, + D=D, + stride_id=d_out.stride(0), + stride_ad=values_a.stride(0), + stride_bd=values_b.stride(0), + n_prefix_to_B=ctx.n_prefix_from_right, + BLOCK_D=BLOCK_D, + ) + + if is_dense_a: + values_a = values_a.reshape((B, dense_size, D)) + elif is_dense_b: + values_b = values_b.reshape((B, dense_size, D)) + return None, values_a, values_b, None, None, None, None + + +class _Split2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + values: torch.Tensor, + max_seq_len: int, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + dense_size: int = 0, + n_prefix_to_right: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + values = switch_to_contiguous_if_needed(values) + is_dense_a: bool = offsets_a is None + is_dense_b: bool = offsets_b is None + if is_dense_a: + L, _ = values.shape + assert offsets_b is not None + B = offsets_b.shape[0] - 1 + seq_len_a = dense_size * B + seq_len_b = L - seq_len_a + offsets_a = offsets_b.new_empty(0) + elif is_dense_b: + L, _ = values.shape + assert offsets_a is not None + B = offsets_a.shape[0] - 1 + seq_len_b = dense_size * B + seq_len_a = L - seq_len_b + offsets_b = offsets_a.new_empty(0) + else: + assert offsets_a is not None and offsets_b is not None + B = offsets_a.shape[0] - 1 + seq_len_a = int(offsets_a[-1].item()) + seq_len_b = int(offsets_b[-1].item()) + _, D = values.shape + BLOCK_D = triton.next_power_of_2(D) + values_a = torch.empty((seq_len_a, D), device=values.device, dtype=values.dtype) + values_b = torch.empty((seq_len_b, D), device=values.device, dtype=values.dtype) + if n_prefix_to_right == 0: + split_2D_jagged[(max_seq_len, B)]( + JaggedIn=values, + DenseSize=dense_size, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=values_a, + OutB=values_b, + D=D, + stride_id=values.stride(0), + stride_ad=values_a.stride(0), + stride_bd=values_b.stride(0), + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, + IS_REPLACE=False, # pyre-ignore[6] + ) + else: + split_2D_jagged_jagged_w_prefix[(max_seq_len, B)]( + JaggedIn=values, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=values_a, + OutB=values_b, + D=D, + stride_id=values.stride(0), + stride_ad=values_a.stride(0), + stride_bd=values_b.stride(0), + n_prefix_to_B=n_prefix_to_right, + BLOCK_D=BLOCK_D, + ) + if is_dense_a: + values_a = values_a.reshape(B, dense_size, D) + if is_dense_b: + values_b = values_b.reshape(B, dense_size, D) + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.seq_len_a = seq_len_a + ctx.seq_len_b = seq_len_b + ctx.is_dense_a = is_dense_a + ctx.is_dense_b = is_dense_b + ctx.dense_size = dense_size + ctx.B = B + ctx.D = D + ctx.n_prefix_to_right = n_prefix_to_right + return values_a, values_b + + @staticmethod + def backward(ctx, *d_values) -> Tuple[torch.Tensor, None, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + is_dense_a, is_dense_b = ctx.is_dense_a, ctx.is_dense_b + values_a, values_b = d_values + if is_dense_a: + stride_dense_batch = values_a.stride(0) + elif is_dense_b: + stride_dense_batch = values_b.stride(0) + else: + stride_dense_batch = 0 + + BLOCK_D = triton.next_power_of_2(ctx.D) + dvalues = torch.empty( + (ctx.seq_len_a + ctx.seq_len_b, ctx.D), + device=values_a.device, + dtype=values_b.dtype, + ) + if ctx.n_prefix_to_right == 0: + concat_2D_jagged[(ctx.max_seq_len, ctx.B)]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + DenseSize=ctx.dense_size, + Out=dvalues, + D=ctx.D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_dense_batch=stride_dense_batch, + stride_od=dvalues.stride(0), + IS_DENSE_A=is_dense_a, + IS_DENSE_B=is_dense_b, + BLOCK_D=BLOCK_D, + IS_REPLACE=False, # pyre-ignore[6] + ) + else: + concat_2D_jagged_jagged_w_prefix[(ctx.max_seq_len, ctx.B)]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + Out=dvalues, + D=ctx.D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_od=dvalues.stride(0), + n_prefix_from_B=ctx.n_prefix_to_right, + BLOCK_D=BLOCK_D, + ) + + return dvalues, None, None, None, None, None + + +@torch.fx.wrap +def triton_jagged_dense_bmm_broadcast_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + return _JaggedDenseBmmBroadcastAddFunction.apply( + max_seq_len, seq_offsets, jagged, dense, bias + ) + + +@torch.fx.wrap +def triton_concat_2D_jagged( + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + is_replace: bool = False, + n_prefix_from_right: int = 0, +) -> torch.Tensor: + return _Concat2DJaggedFunction.apply( + max_seq_len, + values_a, + values_b, + offsets_a, + offsets_b, + is_replace, + n_prefix_from_right, + ) + + +@torch.fx.wrap +def triton_concat_2D_jagged_jagged( + max_seq_len_left: int, + offsets_left: torch.Tensor, + values_left: torch.Tensor, + max_seq_len_right: int, + offsets_right: torch.Tensor, + values_right: torch.Tensor, + is_replace: bool, + n_prefix_from_right: int, +) -> torch.Tensor: + return triton_concat_2D_jagged( + max_seq_len=max_seq_len_left + max_seq_len_right, + values_a=values_left, + values_b=values_right, + offsets_a=offsets_left, + offsets_b=offsets_right, + is_replace=is_replace, + n_prefix_from_right=n_prefix_from_right, + ) + + +@torch.fx.wrap +def triton_concat_2D_dense_jagged( + jagged_max_seq_len: int, + jagged_offsets: torch.Tensor, + jagged_values: torch.Tensor, + dense_values: torch.Tensor, +) -> torch.Tensor: + B, dense_size, D = dense_values.size() + max_seq_len = jagged_max_seq_len + dense_size + return triton_concat_2D_jagged( + max_seq_len=max_seq_len, + values_a=dense_values, + values_b=jagged_values, + offsets_a=None, + offsets_b=jagged_offsets, + ) + + +def triton_jagged_dense_bmm( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, +) -> torch.Tensor: + return _JaggedDenseBmmFunction.apply(max_seq_len, seq_offsets, jagged, dense) + + +def triton_split_2D_jagged( + values: torch.Tensor, + max_seq_len: int, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + dense_size: int = 0, + n_prefix_to_right: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _Split2DJaggedFunction.apply( + values, + max_seq_len, + offsets_a, + offsets_b, + dense_size, + n_prefix_to_right, + ) + + +def triton_jagged_dense_broadcast_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, +) -> torch.Tensor: + return _JaggedDenseBroadcastAddFunction.apply( + max_seq_len, seq_offsets, jagged, dense + ) diff --git a/examples/commons/optimizer/__init__.py b/examples/commons/optimizer/__init__.py new file mode 100644 index 000000000..6f5a747fe --- /dev/null +++ b/examples/commons/optimizer/__init__.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + + +@dataclass +class OptimizerParam: + """ + Configuration for the embedding optimizer. + + Args: + optimizer_str (str): The optimizer type as a string: ``'adam'`` | ``'sgd'``. + learning_rate (float): The learning rate for the optimizer. + adam_beta1 (float, optional): The beta1 parameter for the Adam optimizer. Defaults to 0.9. + adam_beta2 (float, optional): The beta2 parameter for the Adam optimizer. Defaults to 0.95. + adam_eps (float, optional): The epsilon parameter for the Adam optimizer. Defaults to 1e-08. + """ + + optimizer_str: str + learning_rate: float + adam_beta1: float = 0.9 + adam_beta2: float = 0.95 + adam_eps: float = 1e-08 + weight_decay: float = 0.01 diff --git a/examples/hstu/pipeline/train_pipeline.py b/examples/commons/pipeline/train_pipeline.py similarity index 99% rename from examples/hstu/pipeline/train_pipeline.py rename to examples/commons/pipeline/train_pipeline.py index 2af4c571b..8c4061b3c 100644 --- a/examples/hstu/pipeline/train_pipeline.py +++ b/examples/commons/pipeline/train_pipeline.py @@ -39,11 +39,8 @@ import nvtx import torch -from commons.utils.distributed_utils import collective_assert -from distributed.finalize_model_grads import finalize_model_grads -from megatron.core import parallel_state -from megatron.core.distributed.distributed_data_parallel import DistributedDataParallel -from pipeline.utils import ( +from commons.distributed.finalize_model_grads import finalize_model_grads +from commons.pipeline.utils import ( In, Out, PipelinedForward, @@ -59,6 +56,9 @@ _to_device, _wait_for_batch, ) +from commons.utils.distributed_utils import collective_assert +from megatron.core import parallel_state +from megatron.core.distributed.distributed_data_parallel import DistributedDataParallel from torch.autograd.profiler import record_function from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable from torchrec.distributed.model_parallel import ShardedModule diff --git a/examples/hstu/pipeline/utils.py b/examples/commons/pipeline/utils.py similarity index 100% rename from examples/hstu/pipeline/utils.py rename to examples/commons/pipeline/utils.py diff --git a/examples/hstu/setup.py b/examples/commons/setup.py similarity index 100% rename from examples/hstu/setup.py rename to examples/commons/setup.py diff --git a/examples/commons/utils/hooks.py b/examples/commons/utils/hooks.py new file mode 100644 index 000000000..ba9dc94e7 --- /dev/null +++ b/examples/commons/utils/hooks.py @@ -0,0 +1,62 @@ +import torch + + +def module_hook_check_act_has_nan( + module, input, output, msg: str = "", print_nan_indices: bool = False +) -> torch.Tensor: + if isinstance(output, torch.Tensor) and torch.isnan(output).all(): + if print_nan_indices: + nan_indices = output.isnan().nonzero() + print(f"{msg} module {module} has nan output at indices {nan_indices}") + else: + print(f"{msg} module {module} has nan output") + assert False + return output + + +def tensor_hook_check_grad_has_nan( + grad: torch.Tensor, msg: str = "", print_nan_indices: bool = False +) -> torch.Tensor: + if grad.isnan().any(): + if print_nan_indices: + nan_indices = grad.isnan().nonzero() + print(f"{msg} grad has nan at indices {nan_indices}") + else: + print(f"{msg} grad has nan") + assert False + return grad + + +def module_hook_check_act_has_inf( + module, input, output, msg: str = "", print_inf_indices: bool = False +) -> torch.Tensor: + if isinstance(output, torch.Tensor) and torch.isinf(output).any(): + if print_inf_indices: + inf_indices = output.isinf().nonzero() + print(f"{msg} module {module} has inf output at indices {inf_indices}") + else: + print(f"{msg} module {module} has inf output") + assert False + return output + + +def tensor_hook_assert_grad_has_nan(grad: torch.Tensor, msg: str = "") -> torch.Tensor: + assert grad.isnan().any(), f"{msg} grad has nan" + return grad + + +def tensor_hook_check_grad_has_inf( + grad: torch.Tensor, msg: str = "", print_inf_indices: bool = False +) -> torch.Tensor: + if grad.isinf().any(): + if print_inf_indices: + inf_indices = grad.isinf().nonzero() + print(f"{msg} grad has inf at indices {inf_indices}") + else: + print(f"{msg} grad has inf") + return grad + + +def tensor_hook_print_grad(grad: torch.Tensor, msg: str = "") -> torch.Tensor: + print(f"{msg} grad[-1,...]: {grad[-1,...]}") + return grad diff --git a/examples/hstu/configs/__init__.py b/examples/hstu/configs/__init__.py index 640197e20..cc32f5c95 100644 --- a/examples/hstu/configs/__init__.py +++ b/examples/hstu/configs/__init__.py @@ -18,12 +18,7 @@ get_kvcache_config, get_kvcache_metadata_buffer, ) -from .task_config import ( - OptimizerParam, - RankingConfig, - RetrievalConfig, - ShardedEmbeddingConfig, -) +from .task_config import RankingConfig, RetrievalConfig __all__ = [ "hstu_config", @@ -36,8 +31,6 @@ "get_hstu_config", "RankingConfig", "RetrievalConfig", - "OptimizerParam", - "ShardedEmbeddingConfig", "KernelBackend", "HSTULayerType", "KVCacheMetadata", diff --git a/examples/hstu/configs/hstu_config.py b/examples/hstu/configs/hstu_config.py index 16a18fa61..0ad39e97f 100644 --- a/examples/hstu/configs/hstu_config.py +++ b/examples/hstu/configs/hstu_config.py @@ -167,10 +167,10 @@ def get_hstu_config( Create the HSTU configuration. Args: - hidden_size (int): The hidden dimension size. - kv_channels (int): Number of key-value channels (per attention head). - num_attention_heads (int): Number of attention heads. - num_layers (int): Number of attention layers. + hidden_size (int): The hidden dimension size. (TransformerConfig) + kv_channels (int): Number of key-value channels (per attention head). (TransformerConfig) + num_attention_heads (int): Number of attention heads. (TransformerConfig) + num_layers (int): Number of attention layers. (TransformerConfig) dtype (torch.dtype): Data type (e.g., torch.float16). hstu_preprocessing_config (Optional[HSTUPreprocessingConfig], optional): HSTU preprocessing config. Defaults to None. position_encoding_config (Optional[PositionEncodingConfig], optional): Position embedding config. Defaults to None. diff --git a/examples/hstu/configs/task_config.py b/examples/hstu/configs/task_config.py index caa177c95..931eddf72 100644 --- a/examples/hstu/configs/task_config.py +++ b/examples/hstu/configs/task_config.py @@ -15,60 +15,7 @@ from dataclasses import dataclass from typing import List, Tuple, cast - -@dataclass -class OptimizerParam: - """ - Configuration for the embedding optimizer. - - Args: - optimizer_str (str): The optimizer type as a string: ``'adam'`` | ``'sgd'``. - learning_rate (float): The learning rate for the optimizer. - adam_beta1 (float, optional): The beta1 parameter for the Adam optimizer. Defaults to 0.9. - adam_beta2 (float, optional): The beta2 parameter for the Adam optimizer. Defaults to 0.95. - adam_eps (float, optional): The epsilon parameter for the Adam optimizer. Defaults to 1e-08. - """ - - optimizer_str: str - learning_rate: float - adam_beta1: float = 0.9 - adam_beta2: float = 0.95 - adam_eps: float = 1e-08 - weight_decay: float = 0.01 - - -@dataclass -class ShardedEmbeddingConfig: - """ - Configuration for sharded embeddings with sharding type. Inherits from BaseShardedEmbeddingConfig. - - Args: - config (EmbeddingConfig): The embedding configuration. - sharding_type (str): The type of sharding, ``'data_parallel'`` | ``'model_parallel'``. - """ - - """ - Base configuration for sharded embeddings. - - Args: - feature_names (List[str]): The name of the features in this embedding. - table_name (str): The name of the table. - vocab_size (int): The size of the vocabulary. - dim (int): The dimension size of the embeddings. - sharding_type (str): The type of sharding, ``'data_parallel'`` | ``'model_parallel'``. - """ - - feature_names: List[str] - table_name: str - vocab_size: int - dim: int - sharding_type: str - - def __post_init__(self): - assert self.sharding_type in [ - "data_parallel", - "model_parallel", - ], "sharding type should be data_parallel or model_parallel" +from commons.modules.embedding import ShardedEmbeddingConfig @dataclass diff --git a/examples/hstu/dataset/__init__.py b/examples/hstu/datasets/__init__.py similarity index 100% rename from examples/hstu/dataset/__init__.py rename to examples/hstu/datasets/__init__.py diff --git a/examples/hstu/dataset/dummy_dataset.py b/examples/hstu/datasets/dummy_dataset.py similarity index 100% rename from examples/hstu/dataset/dummy_dataset.py rename to examples/hstu/datasets/dummy_dataset.py diff --git a/examples/hstu/dataset/inference_dataset.py b/examples/hstu/datasets/inference_dataset.py similarity index 99% rename from examples/hstu/dataset/inference_dataset.py rename to examples/hstu/datasets/inference_dataset.py index a5db14464..37f8dbb97 100644 --- a/examples/hstu/dataset/inference_dataset.py +++ b/examples/hstu/datasets/inference_dataset.py @@ -34,7 +34,7 @@ import numpy as np import pandas as pd import torch -from dataset.utils import Batch, RankingBatch +from datasets.utils import Batch, RankingBatch from torch.utils.data.dataset import IterableDataset from torchrec.sparse.jagged_tensor import KeyedJaggedTensor diff --git a/examples/hstu/dataset/random_inference_dataset.py b/examples/hstu/datasets/random_inference_dataset.py similarity index 100% rename from examples/hstu/dataset/random_inference_dataset.py rename to examples/hstu/datasets/random_inference_dataset.py diff --git a/examples/hstu/dataset/sequence_dataset.py b/examples/hstu/datasets/sequence_dataset.py similarity index 99% rename from examples/hstu/dataset/sequence_dataset.py rename to examples/hstu/datasets/sequence_dataset.py index 1a256dd7b..720e8db47 100644 --- a/examples/hstu/dataset/sequence_dataset.py +++ b/examples/hstu/datasets/sequence_dataset.py @@ -35,7 +35,7 @@ import pandas as pd import torch from commons.utils.logger import print_rank_0 -from dataset.utils import Batch, RankingBatch, RetrievalBatch +from datasets.utils import Batch, RankingBatch, RetrievalBatch from preprocessor import get_common_preprocessors from torch.utils.data.dataset import IterableDataset from torchrec.sparse.jagged_tensor import KeyedJaggedTensor diff --git a/examples/hstu/dataset/utils.py b/examples/hstu/datasets/utils.py similarity index 100% rename from examples/hstu/dataset/utils.py rename to examples/hstu/datasets/utils.py diff --git a/examples/hstu/inference/benchmark/inference_benchmark.py b/examples/hstu/inference/benchmark/inference_benchmark.py index 534ef8e87..b5d1522e4 100755 --- a/examples/hstu/inference/benchmark/inference_benchmark.py +++ b/examples/hstu/inference/benchmark/inference_benchmark.py @@ -21,8 +21,8 @@ get_inference_hstu_config, get_kvcache_config, ) -from dataset.random_inference_dataset import RandomInferenceDataGenerator -from dataset.utils import FeatureConfig +from datasets.random_inference_dataset import RandomInferenceDataGenerator +from datasets.utils import FeatureConfig sys.path.append("./model/") from inference_ranking_gr import get_inference_ranking_gr diff --git a/examples/hstu/inference/benchmark/paged_hstu_with_kvcache_benchmark.py b/examples/hstu/inference/benchmark/paged_hstu_with_kvcache_benchmark.py index 19430d12b..f0791f840 100755 --- a/examples/hstu/inference/benchmark/paged_hstu_with_kvcache_benchmark.py +++ b/examples/hstu/inference/benchmark/paged_hstu_with_kvcache_benchmark.py @@ -25,7 +25,7 @@ get_inference_hstu_config, get_kvcache_config, ) -from dataset.utils import FeatureConfig +from datasets.utils import FeatureConfig from modules.inference_dense_module import InferenceDenseModule, copy_jagged_metadata from modules.jagged_data import JaggedData diff --git a/examples/hstu/inference/inference_gr_ranking.py b/examples/hstu/inference/inference_gr_ranking.py index e969080f0..f29a28797 100644 --- a/examples/hstu/inference/inference_gr_ranking.py +++ b/examples/hstu/inference/inference_gr_ranking.py @@ -28,9 +28,9 @@ get_inference_hstu_config, get_kvcache_config, ) -from dataset import get_data_loader -from dataset.inference_dataset import InferenceDataset -from dataset.sequence_dataset import get_dataset +from datasets import get_data_loader +from datasets.inference_dataset import InferenceDataset +from datasets.sequence_dataset import get_dataset from modules.metrics import get_multi_event_metric_module from preprocessor import get_common_preprocessors from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor diff --git a/examples/hstu/inference/triton/hstu_model/client.py b/examples/hstu/inference/triton/hstu_model/client.py index 1c3807f01..14c2b80df 100644 --- a/examples/hstu/inference/triton/hstu_model/client.py +++ b/examples/hstu/inference/triton/hstu_model/client.py @@ -30,8 +30,8 @@ import torch import tritonclient.http as httpclient from commons.utils.stringify import stringify_dict -from dataset import get_data_loader -from dataset.sequence_dataset import get_dataset +from datasets import get_data_loader +from datasets.sequence_dataset import get_dataset from modules.metrics import get_multi_event_metric_module from preprocessor import get_common_preprocessors from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor diff --git a/examples/hstu/inference/triton/hstu_model/model.py b/examples/hstu/inference/triton/hstu_model/model.py index d4d23b7e9..8a5261748 100644 --- a/examples/hstu/inference/triton/hstu_model/model.py +++ b/examples/hstu/inference/triton/hstu_model/model.py @@ -42,7 +42,7 @@ RankingConfig, get_inference_hstu_config, ) -from dataset.utils import Batch +from datasets.utils import Batch from modules.inference_dense_module import get_inference_dense_model from torch.utils.dlpack import from_dlpack, to_dlpack from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor diff --git a/examples/hstu/model/base_model.py b/examples/hstu/model/base_model.py index d72839a75..223aa35d0 100644 --- a/examples/hstu/model/base_model.py +++ b/examples/hstu/model/base_model.py @@ -16,7 +16,7 @@ from typing import Tuple import torch -from dataset.utils import Batch +from datasets.utils import Batch class BaseModel(torch.nn.Module): diff --git a/examples/hstu/model/inference_ranking_gr.py b/examples/hstu/model/inference_ranking_gr.py index e42c1f371..2e09247f2 100755 --- a/examples/hstu/model/inference_ranking_gr.py +++ b/examples/hstu/model/inference_ranking_gr.py @@ -17,7 +17,7 @@ import torch from configs import InferenceHSTUConfig, KVCacheConfig, RankingConfig -from dataset.utils import Batch +from datasets.utils import Batch from modules.inference_dense_module import InferenceDenseModule from modules.inference_embedding import InferenceEmbedding diff --git a/examples/hstu/model/ranking_gr.py b/examples/hstu/model/ranking_gr.py index 18f7bd4a8..38d449de4 100644 --- a/examples/hstu/model/ranking_gr.py +++ b/examples/hstu/model/ranking_gr.py @@ -15,12 +15,15 @@ from typing import Any, Dict, Tuple import torch +from commons.distributed.dmp_to_tp import ( + dmp_batch_to_tp, + jt_dict_grad_scaling_and_allgather, +) +from commons.modules.embedding import ShardedEmbedding from configs import HSTUConfig, RankingConfig -from dataset.utils import RankingBatch -from distributed.dmp_to_tp import dmp_batch_to_tp, jt_dict_grad_scaling_and_allgather +from datasets.utils import RankingBatch from megatron.core import parallel_state from model.base_model import BaseModel -from modules.embedding import ShardedEmbedding from modules.hstu_block import HSTUBlock from modules.metrics import get_multi_event_metric_module from modules.mlp import MLP diff --git a/examples/hstu/model/retrieval_gr.py b/examples/hstu/model/retrieval_gr.py index bf0826f07..f62f49676 100644 --- a/examples/hstu/model/retrieval_gr.py +++ b/examples/hstu/model/retrieval_gr.py @@ -15,21 +15,21 @@ from typing import Any, Tuple import torch +from commons.modules.embedding import ShardedEmbedding +from commons.ops.length_to_offsets import length_to_complete_offsets +from commons.ops.triton_ops.triton_jagged import ( # type: ignore[attr-defined] + triton_split_2D_jagged, +) from commons.utils.nvtx_op import output_nvtx_hook from configs import HSTUConfig, RetrievalConfig -from dataset.utils import RetrievalBatch +from datasets.utils import RetrievalBatch from megatron.core import parallel_state from model.base_model import BaseModel -from modules.embedding import ShardedEmbedding from modules.hstu_block import HSTUBlock from modules.negatives_sampler import InBatchNegativesSampler from modules.output_postprocessors import L2NormEmbeddingPostprocessor from modules.sampled_softmax_loss import SampledSoftmaxLoss from modules.similarity.dot_product import DotProductSimilarity -from ops.length_to_offsets import length_to_complete_offsets -from ops.triton_ops.triton_jagged import ( # type: ignore[attr-defined] - triton_split_2D_jagged, -) class RetrievalGR(BaseModel): diff --git a/examples/hstu/modules/debug/debug_hstu_layer.py b/examples/hstu/modules/debug/debug_hstu_layer.py index bbf8d45e2..92be1e1ef 100644 --- a/examples/hstu/modules/debug/debug_hstu_layer.py +++ b/examples/hstu/modules/debug/debug_hstu_layer.py @@ -20,6 +20,7 @@ import nvtx import torch import torch.nn.functional as F +from commons.ops.collective_ops import gather_along_last_dim, split_along_last_dim from commons.utils.distributed_utils import ( collective_assert_tensor, grad_collective_equal_assert_hook, @@ -35,7 +36,6 @@ from modules.hstu_attention import create_hstu_attention from modules.jagged_data import JaggedData from modules.utils import init_mlp_weights_optional_bias -from ops.collective_ops import gather_along_last_dim, split_along_last_dim from ops.pt_ops.pt_norm_mul_dropout import pytorch_norm_mul_dropout from ops.triton_ops.triton_norm_mul_dropout import triton_norm_mul_dropout diff --git a/examples/hstu/modules/hstu_block.py b/examples/hstu/modules/hstu_block.py index ea1a2da23..6a77bfe51 100644 --- a/examples/hstu/modules/hstu_block.py +++ b/examples/hstu/modules/hstu_block.py @@ -5,7 +5,7 @@ import torch from commons.utils.nvtx_op import output_nvtx_hook from configs.hstu_config import HSTUConfig, HSTULayerType -from dataset.utils import RankingBatch, RetrievalBatch +from datasets.utils import RankingBatch, RetrievalBatch from megatron.core.transformer.module import MegatronModule from modules.debug.debug_hstu_layer import HSTULayer as DebugHSTULayer from modules.fused_hstu_layer import FusedHSTULayer diff --git a/examples/hstu/modules/hstu_block_inference.py b/examples/hstu/modules/hstu_block_inference.py index 70bfca945..734d23551 100644 --- a/examples/hstu/modules/hstu_block_inference.py +++ b/examples/hstu/modules/hstu_block_inference.py @@ -5,7 +5,7 @@ import torch from configs import InferenceHSTUConfig, KVCacheConfig -from dataset.utils import Batch +from datasets.utils import Batch from modules.hstu_processor import HSTUBlockPostprocessor, HSTUBlockPreprocessor from modules.jagged_data import JaggedData from modules.paged_hstu_infer_layer import PagedHSTUInferLayer diff --git a/examples/hstu/modules/hstu_processor.py b/examples/hstu/modules/hstu_processor.py index 38c99c4dc..567b2845b 100644 --- a/examples/hstu/modules/hstu_processor.py +++ b/examples/hstu/modules/hstu_processor.py @@ -16,16 +16,16 @@ from typing import Dict, Optional, Union import torch +from commons.ops.cuda_ops.JaggedTensorOpFunction import jagged_2D_tensor_concat +from commons.ops.length_to_offsets import length_to_complete_offsets +from commons.ops.triton_ops.triton_jagged import triton_split_2D_jagged from commons.utils.nvtx_op import output_nvtx_hook from configs.hstu_config import HSTUConfig from configs.inference_config import InferenceHSTUConfig -from dataset.utils import RankingBatch +from datasets.utils import RankingBatch from modules.jagged_data import JaggedData, pad_jd_values, unpad_jd_values from modules.mlp import MLP from modules.position_encoder import HSTUPositionalEncoder -from ops.cuda_ops.JaggedTensorOpFunction import jagged_2D_tensor_concat -from ops.length_to_offsets import length_to_complete_offsets -from ops.triton_ops.triton_jagged import triton_split_2D_jagged from torchrec.sparse.jagged_tensor import JaggedTensor try: diff --git a/examples/hstu/modules/inference_dense_module.py b/examples/hstu/modules/inference_dense_module.py index 5435731fb..36e17d69d 100755 --- a/examples/hstu/modules/inference_dense_module.py +++ b/examples/hstu/modules/inference_dense_module.py @@ -25,7 +25,7 @@ copy_kvcache_metadata, get_kvcache_metadata_buffer, ) -from dataset.utils import Batch +from datasets.utils import Batch from modules.hstu_block_inference import HSTUBlockInference from modules.jagged_data import JaggedData from modules.mlp import MLP diff --git a/examples/hstu/modules/metrics/metric_modules.py b/examples/hstu/modules/metrics/metric_modules.py index ffe6479c7..bcb871f33 100644 --- a/examples/hstu/modules/metrics/metric_modules.py +++ b/examples/hstu/modules/metrics/metric_modules.py @@ -23,11 +23,11 @@ import numpy as np import torch import torchmetrics.classification as classification_metrics +from commons.ops.collective_ops import grouped_allgatherv_tensor_list from commons.utils.nvtx_op import output_nvtx_hook from dynamicemb.planner import ( DynamicEmbeddingShardingPlanner as DynamicEmbeddingShardingPlanner, ) -from ops.collective_ops import grouped_allgatherv_tensor_list try: from megatron.core import parallel_state diff --git a/examples/hstu/modules/native_hstu_layer.py b/examples/hstu/modules/native_hstu_layer.py index 7c7c86343..c157f6d0f 100644 --- a/examples/hstu/modules/native_hstu_layer.py +++ b/examples/hstu/modules/native_hstu_layer.py @@ -19,6 +19,7 @@ import nvtx import torch import torch.nn.functional as F +from commons.ops.collective_ops import gather_along_last_dim, split_along_first_dim from commons.utils.clear_tensor_data import clear_tensor_data from commons.utils.nvtx_op import output_nvtx_hook, register_setter_and_getter_for_nvtx from configs import HSTUConfig @@ -34,7 +35,6 @@ from modules.hstu_attention import create_hstu_attention from modules.jagged_data import JaggedData from modules.tp_layer_norm import TPLayerNormMulDropout -from ops.collective_ops import gather_along_last_dim, split_along_first_dim from ops.triton_ops.triton_layer_norm import triton_layer_norm diff --git a/examples/hstu/modules/position_encoder.py b/examples/hstu/modules/position_encoder.py index 52fb3b3a2..a1194a23a 100644 --- a/examples/hstu/modules/position_encoder.py +++ b/examples/hstu/modules/position_encoder.py @@ -34,8 +34,11 @@ from typing import Optional import torch +from commons.ops.triton_ops.common import ( + set_static_max_seq_lens, + set_use_runtime_max_seq_len, +) from commons.utils.nvtx_op import output_nvtx_hook -from ops.triton_ops.common import set_static_max_seq_lens, set_use_runtime_max_seq_len from ops.triton_ops.triton_position import ( # type: ignore[attr-defined] triton_add_position_embeddings, triton_add_timestamp_positional_embeddings, diff --git a/examples/hstu/modules/tp_layer_norm.py b/examples/hstu/modules/tp_layer_norm.py index 766435374..f937fce8d 100644 --- a/examples/hstu/modules/tp_layer_norm.py +++ b/examples/hstu/modules/tp_layer_norm.py @@ -1,11 +1,11 @@ import megatron.core.parallel_state as parallel_state import torch -from megatron.core.tensor_parallel.mappings import all_to_all_hp2sp, all_to_all_sp2hp -from ops.collective_ops import ( +from commons.ops.collective_ops import ( gather_along_first_dim, gather_along_last_dim, split_along_last_dim, ) +from megatron.core.tensor_parallel.mappings import all_to_all_hp2sp, all_to_all_sp2hp from ops.pt_ops.pt_norm_mul_dropout import pytorch_norm_mul_dropout from ops.triton_ops.triton_layer_norm import triton_layer_norm from ops.triton_ops.triton_norm_mul_dropout import triton_norm_mul_dropout diff --git a/examples/hstu/ops/triton_ops/triton_addmm.py b/examples/hstu/ops/triton_ops/triton_addmm.py index cffed4915..62cf03cda 100644 --- a/examples/hstu/ops/triton_ops/triton_addmm.py +++ b/examples/hstu/ops/triton_ops/triton_addmm.py @@ -24,7 +24,7 @@ # @manual=//triton:triton import triton.language as tl -from ops.triton_ops.common import triton_autotune +from commons.ops.triton_ops.common import triton_autotune ENABLE_FULL_TURNING_SPACE = False diff --git a/examples/hstu/ops/triton_ops/triton_hstu_attention.py b/examples/hstu/ops/triton_ops/triton_hstu_attention.py index 9f5261df6..ae4af8d17 100644 --- a/examples/hstu/ops/triton_ops/triton_hstu_attention.py +++ b/examples/hstu/ops/triton_ops/triton_hstu_attention.py @@ -37,7 +37,7 @@ # @manual=//triton:triton import triton.language as tl -from ops.triton_ops.common import ( +from commons.ops.triton_ops.common import ( NamedSpecType, VersionedSpec, autotune_max_seq_len, diff --git a/examples/hstu/ops/triton_ops/triton_jagged.py b/examples/hstu/ops/triton_ops/triton_jagged.py index 69a56aef2..026eb327b 100644 --- a/examples/hstu/ops/triton_ops/triton_jagged.py +++ b/examples/hstu/ops/triton_ops/triton_jagged.py @@ -38,7 +38,7 @@ # @manual=//triton:triton import triton.language as tl -from ops.triton_ops.common import ( +from commons.ops.triton_ops.common import ( autotune_max_seq_len, switch_to_contiguous_if_needed, triton_autotune, diff --git a/examples/hstu/ops/triton_ops/triton_layer_norm.py b/examples/hstu/ops/triton_ops/triton_layer_norm.py index 1468b21af..b49180041 100644 --- a/examples/hstu/ops/triton_ops/triton_layer_norm.py +++ b/examples/hstu/ops/triton_ops/triton_layer_norm.py @@ -24,7 +24,10 @@ # @manual=//triton:triton import triton.language as tl -from ops.triton_ops.common import switch_to_contiguous_if_needed, triton_autotune +from commons.ops.triton_ops.common import ( + switch_to_contiguous_if_needed, + triton_autotune, +) @triton.jit diff --git a/examples/hstu/ops/triton_ops/triton_norm_mul_dropout.py b/examples/hstu/ops/triton_ops/triton_norm_mul_dropout.py index eeb44009e..8be556b00 100644 --- a/examples/hstu/ops/triton_ops/triton_norm_mul_dropout.py +++ b/examples/hstu/ops/triton_ops/triton_norm_mul_dropout.py @@ -24,7 +24,10 @@ # @manual=//triton:triton import triton.language as tl -from ops.triton_ops.common import switch_to_contiguous_if_needed, triton_autotune +from commons.ops.triton_ops.common import ( + switch_to_contiguous_if_needed, + triton_autotune, +) @triton.jit diff --git a/examples/hstu/ops/triton_ops/triton_position.py b/examples/hstu/ops/triton_ops/triton_position.py index a7bd89c7a..0fcfa52ee 100644 --- a/examples/hstu/ops/triton_ops/triton_position.py +++ b/examples/hstu/ops/triton_ops/triton_position.py @@ -44,7 +44,7 @@ except OSError: pass -from ops.triton_ops.common import ( +from commons.ops.triton_ops.common import ( autotune_max_seq_len, prev_power_of_2, switch_to_contiguous_if_needed, diff --git a/examples/hstu/ops/triton_ops/triton_silu.py b/examples/hstu/ops/triton_ops/triton_silu.py index 59251abee..e68cc7591 100644 --- a/examples/hstu/ops/triton_ops/triton_silu.py +++ b/examples/hstu/ops/triton_ops/triton_silu.py @@ -28,7 +28,7 @@ # @manual=//triton:triton from triton.language.math import fast_dividef -from ops.triton_ops.common import triton_autotune +from commons.ops.triton_ops.common import triton_autotune def silu_configs(): diff --git a/examples/hstu/test/conftest.py b/examples/hstu/test/conftest.py new file mode 100644 index 000000000..c31fd3ccb --- /dev/null +++ b/examples/hstu/test/conftest.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Pytest configuration for limiting parametrized test cases during debugging. + +Usage: + # Run only the first parameter combination for all parametrized tests + PYTEST_FIRST_PARAM_ONLY=1 pytest + + # Run all parameter combinations (default behavior) + pytest +""" + +import os + + +def pytest_configure(config): + """Register custom marker for first parameter only mode.""" + config.addinivalue_line( + "markers", + "first_param_only: automatically added when PYTEST_FIRST_PARAM_ONLY is set", + ) + + +def pytest_generate_tests(metafunc): + """ + Hook to modify parametrized tests to only run the first parameter combination. + + This is useful during debugging when you want to quickly test if the test + infrastructure works without running all parameter combinations. + + Set environment variable PYTEST_FIRST_PARAM_ONLY=1 to enable this behavior. + """ + # Check if we should limit to first parameter only + if os.getenv("PYTEST_FIRST_PARAM_ONLY", "0") == "1": + # Check if this test function has parametrize + if hasattr(metafunc, "definition") and hasattr(metafunc.definition, "callspec"): + # This will be handled by pytest_collection_modifyitems + pass + + +def pytest_collection_modifyitems(config, items): + """ + Modify collected test items to only keep the first parameter combination + for each parametrized test function when PYTEST_FIRST_PARAM_ONLY is set. + """ + if os.getenv("PYTEST_FIRST_PARAM_ONLY", "0") != "1": + return + + # Group items by their base test name (without parameter suffix) + test_groups = {} + for item in items: + # Get the base test name (remove parameter suffix like [param0]) + base_name = item.nodeid.split("[")[0] + if base_name not in test_groups: + test_groups[base_name] = [] + test_groups[base_name].append(item) + + # Filter to keep only the first item from each group that has multiple items + items_to_keep = [] + for base_name, group_items in test_groups.items(): + if len(group_items) > 1: + # This is a parametrized test, keep only the first one + items_to_keep.append(group_items[0]) + print( + f" [PYTEST_FIRST_PARAM_ONLY] Keeping only first parameter for: {base_name}" + ) + else: + # Not parametrized or only one parameter, keep it + items_to_keep.extend(group_items) + + # Replace the items list + items[:] = items_to_keep diff --git a/examples/hstu/test/tensor_parallel/test_tp_hstu_layer.py b/examples/hstu/test/tensor_parallel/test_tp_hstu_layer.py index 4ab3cf789..cdd7165f7 100644 --- a/examples/hstu/test/tensor_parallel/test_tp_hstu_layer.py +++ b/examples/hstu/test/tensor_parallel/test_tp_hstu_layer.py @@ -20,17 +20,17 @@ import pytest import torch from commons.checkpoint import get_unwrapped_module +from commons.distributed.finalize_model_grads import finalize_model_grads +from commons.ops.length_to_offsets import length_to_complete_offsets from commons.utils.distributed_utils import collective_assert, collective_assert_tensor from commons.utils.hstu_assert_close import hstu_close from configs.hstu_config import HSTULayerType, KernelBackend -from distributed.finalize_model_grads import finalize_model_grads from megatron.core import parallel_state from megatron.core.tensor_parallel.mappings import ( gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region, ) from modules.jagged_data import JaggedData, pad_jd_values, unpad_jd_values -from ops.length_to_offsets import length_to_complete_offsets from test_utils import ( compare_tpN_to_debug_weights, create_hstu_layer_and_optimizer, diff --git a/examples/hstu/test/tensor_parallel/test_tp_ranking_gr.py b/examples/hstu/test/tensor_parallel/test_tp_ranking_gr.py index 754f5c397..5013b0d7b 100755 --- a/examples/hstu/test/tensor_parallel/test_tp_ranking_gr.py +++ b/examples/hstu/test/tensor_parallel/test_tp_ranking_gr.py @@ -20,11 +20,11 @@ import pytest import torch from commons.checkpoint import get_unwrapped_module +from commons.pipeline.train_pipeline import JaggedMegatronTrainNonePipeline from commons.utils.distributed_utils import collective_assert, collective_assert_tensor from commons.utils.hstu_assert_close import hstu_close from configs import HSTULayerType, KernelBackend from megatron.core import parallel_state -from pipeline.train_pipeline import JaggedMegatronTrainNonePipeline from test_utils import ( compare_tpN_to_debug_weights, create_model, diff --git a/examples/hstu/test/test_checkpointing.py b/examples/hstu/test/test_checkpointing.py index a22e6a7bb..36d5a5238 100644 --- a/examples/hstu/test/test_checkpointing.py +++ b/examples/hstu/test/test_checkpointing.py @@ -21,8 +21,8 @@ import pytest import torch import torch.distributed as dist +from commons.distributed.finalize_model_grads import finalize_model_grads from commons.utils.distributed_utils import collective_assert -from distributed.finalize_model_grads import finalize_model_grads from test_utils import assert_equal_two_state_dict, create_model @@ -164,7 +164,7 @@ def test_checkpoint_model( init.destroy_global_state() -from modules.embedding import DataParallelEmbeddingCollection +from commons.modules.embedding import DataParallelEmbeddingCollection from torchrec.distributed.planner import EmbeddingShardingPlanner from torchrec.distributed.planner.types import ParameterConstraints from torchrec.distributed.types import BoundsCheckMode, ShardingEnv, ShardingType diff --git a/examples/hstu/test/test_collective.py b/examples/hstu/test/test_collective.py index 25be0bb2b..0c8d948aa 100644 --- a/examples/hstu/test/test_collective.py +++ b/examples/hstu/test/test_collective.py @@ -16,13 +16,13 @@ import fbgemm_gpu # for jagged_to_padded_dense import pytest import torch -from megatron.core import parallel_state, tensor_parallel -from ops.collective_ops import ( +from commons.ops.collective_ops import ( gather_along_first_dim, gatherv_along_first_dim, grouped_allgatherv_tensor_list, ) -from ops.length_to_offsets import length_to_complete_offsets +from commons.ops.length_to_offsets import length_to_complete_offsets +from megatron.core import parallel_state, tensor_parallel def get_source_and_ref_tensor(shape=(128, 1), dtype=torch.float): diff --git a/examples/hstu/test/test_dataset.py b/examples/hstu/test/test_dataset.py index c3552f91e..88161d10f 100644 --- a/examples/hstu/test/test_dataset.py +++ b/examples/hstu/test/test_dataset.py @@ -18,10 +18,10 @@ import fbgemm_gpu # to load permute_2D_sparse_data import pytest import torch -from dataset import get_data_loader -from dataset.dummy_dataset import DummySequenceDataset -from dataset.sequence_dataset import get_dataset -from dataset.utils import FeatureConfig, RankingBatch, RetrievalBatch, is_batch_valid +from datasets import get_data_loader +from datasets.dummy_dataset import DummySequenceDataset +from datasets.sequence_dataset import get_dataset +from datasets.utils import FeatureConfig, RankingBatch, RetrievalBatch, is_batch_valid from torch import distributed as dist from torchrec.sparse.jagged_tensor import KeyedJaggedTensor diff --git a/examples/hstu/test/test_hstu_block_inference.py b/examples/hstu/test/test_hstu_block_inference.py index 822b9413c..7a5a2ef2d 100755 --- a/examples/hstu/test/test_hstu_block_inference.py +++ b/examples/hstu/test/test_hstu_block_inference.py @@ -21,8 +21,8 @@ get_inference_hstu_config, get_kvcache_config, ) -from dataset.random_inference_dataset import RandomInferenceDataGenerator -from dataset.utils import FeatureConfig +from datasets.random_inference_dataset import RandomInferenceDataGenerator +from datasets.utils import FeatureConfig sys.path.append("./model/") from inference_ranking_gr import InferenceRankingGR diff --git a/examples/hstu/test/test_hstu_layer.py b/examples/hstu/test/test_hstu_layer.py index 4ab5d35d0..0fd82e0af 100644 --- a/examples/hstu/test/test_hstu_layer.py +++ b/examples/hstu/test/test_hstu_layer.py @@ -18,6 +18,7 @@ import fbgemm_gpu # pylint: disable-unused-import import pytest import torch +from commons.ops.length_to_offsets import length_to_complete_offsets from commons.utils.hstu_assert_close import assert_hstu_close from configs import get_hstu_config from configs.hstu_config import HSTULayerType, KernelBackend @@ -25,7 +26,6 @@ from modules.debug.debug_hstu_layer import HSTULayer as DebugHSTULayer from modules.fused_hstu_layer import FusedHSTULayer from modules.jagged_data import JaggedData -from ops.length_to_offsets import length_to_complete_offsets from test_utils import init_fused_weights_from_debug diff --git a/examples/hstu/test/test_hstu_op.py b/examples/hstu/test/test_hstu_op.py index 356a6a9d2..dc6fa0051 100644 --- a/examples/hstu/test/test_hstu_op.py +++ b/examples/hstu/test/test_hstu_op.py @@ -19,6 +19,7 @@ import fbgemm_gpu # pylint: disable-unused-import import pytest import torch +from commons.ops.length_to_offsets import length_to_complete_offsets from commons.utils.hstu_assert_close import assert_hstu_close from configs import get_hstu_config from configs.hstu_config import HSTULayerType, KernelBackend @@ -27,7 +28,6 @@ from modules.hstu_attention import create_hstu_attention from modules.jagged_data import JaggedData from ops.fused_hstu_op import fused_hstu_op -from ops.length_to_offsets import length_to_complete_offsets def generate_or_copy_parameters( diff --git a/examples/hstu/test/test_hstu_preprocess.py b/examples/hstu/test/test_hstu_preprocess.py index 22925c7cb..4e6b80854 100644 --- a/examples/hstu/test/test_hstu_preprocess.py +++ b/examples/hstu/test/test_hstu_preprocess.py @@ -16,7 +16,7 @@ import pytest import torch from configs import get_hstu_config -from dataset.utils import Batch, FeatureConfig +from datasets.utils import Batch, FeatureConfig from modules.hstu_block import HSTUBlock from torchrec.sparse.jagged_tensor import KeyedJaggedTensor diff --git a/examples/hstu/test/test_jagged_tensor.py b/examples/hstu/test/test_jagged_tensor.py index b2b81c511..fa50237e7 100644 --- a/examples/hstu/test/test_jagged_tensor.py +++ b/examples/hstu/test/test_jagged_tensor.py @@ -3,7 +3,7 @@ import pytest import torch import torch.distributed as dist -from ops.cuda_ops.JaggedTensorOpFunction import jagged_2D_tensor_concat +from commons.ops.cuda_ops.JaggedTensorOpFunction import jagged_2D_tensor_concat from torchrec.sparse.jagged_tensor import JaggedTensor backend = "nccl" diff --git a/examples/hstu/test/test_kvcache.py b/examples/hstu/test/test_kvcache.py index de0f8ccd9..90d3f4046 100755 --- a/examples/hstu/test/test_kvcache.py +++ b/examples/hstu/test/test_kvcache.py @@ -21,7 +21,7 @@ get_inference_hstu_config, get_kvcache_config, ) -from dataset.utils import Batch, FeatureConfig +from datasets.utils import Batch, FeatureConfig sys.path.append("./model/") from inference_ranking_gr import InferenceRankingGR diff --git a/examples/hstu/test/test_metrics.py b/examples/hstu/test/test_metrics.py index c39df3bd7..f458f76c7 100644 --- a/examples/hstu/test/test_metrics.py +++ b/examples/hstu/test/test_metrics.py @@ -19,9 +19,9 @@ import commons.utils.initialize as init import pytest import torch +from commons.ops.collective_ops import grouped_allgatherv_tensor_list from megatron.core import parallel_state, tensor_parallel from modules.metrics.metric_modules import RetrievalTaskMetricWithSampling -from ops.collective_ops import grouped_allgatherv_tensor_list @pytest.mark.parametrize("num_embeddings", [10000]) diff --git a/examples/hstu/test/test_pipeline.py b/examples/hstu/test/test_pipeline.py index 517520595..ffa62c666 100644 --- a/examples/hstu/test/test_pipeline.py +++ b/examples/hstu/test/test_pipeline.py @@ -21,13 +21,13 @@ import pytest import torch import torch.distributed as dist -from commons.utils.distributed_utils import collective_assert -from distributed.finalize_model_grads import finalize_model_grads -from pipeline.train_pipeline import ( +from commons.distributed.finalize_model_grads import finalize_model_grads +from commons.pipeline.train_pipeline import ( JaggedMegatronPrefetchTrainPipelineSparseDist, JaggedMegatronTrainNonePipeline, JaggedMegatronTrainPipelineSparseDist, ) +from commons.utils.distributed_utils import collective_assert from test_utils import create_model diff --git a/examples/hstu/test/test_position_encoder.py b/examples/hstu/test/test_position_encoder.py index 0bd728bc8..b0bbc7e14 100644 --- a/examples/hstu/test/test_position_encoder.py +++ b/examples/hstu/test/test_position_encoder.py @@ -14,10 +14,10 @@ # limitations under the License. import pytest import torch +from commons.ops.length_to_offsets import length_to_complete_offsets -# from dataset.utils import Batch, FeatureConfig +# from datasets.utils import Batch, FeatureConfig from modules.position_encoder import HSTUPositionalEncoder -from ops.length_to_offsets import length_to_complete_offsets from ops.triton_ops.triton_jagged import triton_concat_2D_jagged diff --git a/examples/hstu/test_utils.py b/examples/hstu/test_utils.py index 7c4540f45..40f08f29b 100755 --- a/examples/hstu/test_utils.py +++ b/examples/hstu/test_utils.py @@ -18,13 +18,15 @@ import commons.utils as init import configs -import dataset +import datasets import model import torch +from commons.distributed.sharding import apply_megatron_ddp, make_optimizer_and_shard +from commons.modules.embedding import ShardedEmbeddingConfig +from commons.optimizer import OptimizerParam from commons.utils.distributed_utils import collective_assert from commons.utils.hstu_assert_close import hstu_close -from configs import HSTULayerType, KernelBackend, OptimizerParam -from distributed.sharding import apply_megatron_ddp, make_optimizer_and_shard +from configs import HSTULayerType, KernelBackend from dynamicemb import DynamicEmbTableOptions from megatron.core import parallel_state, tensor_parallel from modules.debug.debug_hstu_layer import HSTULayer as DebugHSTULayer @@ -348,14 +350,14 @@ def create_model( item_emb_size = 1024 * 1024 action_vocab_size = 16 emb_configs = [ - configs.ShardedEmbeddingConfig( + ShardedEmbeddingConfig( feature_names=[action_feature_name], table_name="act", vocab_size=action_vocab_size, dim=embdim, sharding_type="data_parallel", ), - configs.ShardedEmbeddingConfig( + ShardedEmbeddingConfig( feature_names=[item_feature_name], table_name="item", vocab_size=item_emb_size, @@ -364,7 +366,7 @@ def create_model( ), ] feature_configs = [ - dataset.utils.FeatureConfig( + datasets.utils.FeatureConfig( feature_names=[item_feature_name, action_feature_name], max_item_ids=[ max(item_emb_size // 2, 1), @@ -376,7 +378,7 @@ def create_model( ] if len(contextual_feature_names) > 0: feature_configs.append( - dataset.utils.FeatureConfig( + datasets.utils.FeatureConfig( feature_names=contextual_feature_names, max_item_ids=[ contextual_emb_size for _ in range(len(contextual_feature_names)) @@ -386,7 +388,7 @@ def create_model( ) ) emb_configs.append( - configs.ShardedEmbeddingConfig( + ShardedEmbeddingConfig( feature_names=contextual_feature_names, table_name="context", vocab_size=contextual_emb_size, @@ -417,13 +419,13 @@ def create_model( with tensor_parallel.get_cuda_rng_tracker().fork(): if replicate_batches: history_batches = [ - dataset.utils.RankingBatch.random( + datasets.utils.RankingBatch.random( num_tasks=num_tasks, **batch_kwargs ) ] * num_batches else: history_batches = [ - dataset.utils.RankingBatch.random( + datasets.utils.RankingBatch.random( num_tasks=num_tasks, **batch_kwargs ) for _ in range(num_batches) @@ -439,11 +441,11 @@ def create_model( with tensor_parallel.get_cuda_rng_tracker().fork(): if replicate_batches: history_batches = [ - dataset.utils.RetrievalBatch.random(**batch_kwargs) + datasets.utils.RetrievalBatch.random(**batch_kwargs) ] * num_batches else: history_batches = [ - dataset.utils.RetrievalBatch.random(**batch_kwargs) + datasets.utils.RetrievalBatch.random(**batch_kwargs) for _ in range(num_batches) ] optimizer_param = OptimizerParam( diff --git a/examples/hstu/training/benchmark/hstu_layer_benchmark.py b/examples/hstu/training/benchmark/hstu_layer_benchmark.py index 1e2765ae3..4af9cb35e 100644 --- a/examples/hstu/training/benchmark/hstu_layer_benchmark.py +++ b/examples/hstu/training/benchmark/hstu_layer_benchmark.py @@ -35,6 +35,7 @@ import click import commons.utils.initialize as init import nvtx +from commons.ops.length_to_offsets import length_to_complete_offsets from commons.utils.gpu_timer import IGPUTimer from configs.hstu_config import ( HSTUConfig, @@ -46,7 +47,6 @@ from modules.fused_hstu_layer import FusedHSTULayer from modules.jagged_data import JaggedData from modules.native_hstu_layer import HSTULayer as NativeHSTULayer -from ops.length_to_offsets import length_to_complete_offsets from training.trainer.utils import cal_flops_single_rank _backend_str_to_type = { diff --git a/examples/hstu/training/pretrain_gr_ranking.py b/examples/hstu/training/pretrain_gr_ranking.py index 9ff0b32e2..2a592590b 100644 --- a/examples/hstu/training/pretrain_gr_ranking.py +++ b/examples/hstu/training/pretrain_gr_ranking.py @@ -23,17 +23,17 @@ import commons.utils.initialize as init import gin import torch # pylint: disable-unused-import +from commons.distributed.sharding import make_optimizer_and_shard +from commons.pipeline.train_pipeline import ( + JaggedMegatronPrefetchTrainPipelineSparseDist, + JaggedMegatronTrainNonePipeline, + JaggedMegatronTrainPipelineSparseDist, +) from commons.utils.logger import print_rank_0 from configs import RankingConfig -from distributed.sharding import make_optimizer_and_shard from megatron.core import parallel_state from model import get_ranking_model from modules.metrics import get_multi_event_metric_module -from pipeline.train_pipeline import ( - JaggedMegatronPrefetchTrainPipelineSparseDist, - JaggedMegatronTrainNonePipeline, - JaggedMegatronTrainPipelineSparseDist, -) from trainer.training import maybe_load_ckpts, train_with_pipeline from trainer.utils import ( create_dynamic_optitons_dict, diff --git a/examples/hstu/training/pretrain_gr_retrieval.py b/examples/hstu/training/pretrain_gr_retrieval.py index c628c5350..124e55145 100644 --- a/examples/hstu/training/pretrain_gr_retrieval.py +++ b/examples/hstu/training/pretrain_gr_retrieval.py @@ -23,15 +23,15 @@ import commons.utils.initialize as init import gin import torch # pylint: disable-unused-import -from configs import RetrievalConfig -from distributed.sharding import make_optimizer_and_shard -from model import get_retrieval_model -from modules.metrics import RetrievalTaskMetricWithSampling -from pipeline.train_pipeline import ( +from commons.distributed.sharding import make_optimizer_and_shard +from commons.pipeline.train_pipeline import ( JaggedMegatronPrefetchTrainPipelineSparseDist, JaggedMegatronTrainNonePipeline, JaggedMegatronTrainPipelineSparseDist, ) +from configs import RetrievalConfig +from model import get_retrieval_model +from modules.metrics import RetrievalTaskMetricWithSampling from trainer.training import maybe_load_ckpts, train_with_pipeline from trainer.utils import ( create_dynamic_optitons_dict, diff --git a/examples/hstu/training/trainer/training.py b/examples/hstu/training/trainer/training.py index a4b0254c8..91769878d 100644 --- a/examples/hstu/training/trainer/training.py +++ b/examples/hstu/training/trainer/training.py @@ -20,17 +20,17 @@ import torch # pylint: disable-unused-import import torch.distributed as dist from commons.checkpoint import get_unwrapped_module +from commons.pipeline.train_pipeline import ( + JaggedMegatronPrefetchTrainPipelineSparseDist, + JaggedMegatronTrainNonePipeline, + JaggedMegatronTrainPipelineSparseDist, +) from commons.utils.gpu_timer import GPUTimer from commons.utils.logger import print_rank_0 from commons.utils.stringify import stringify_dict from megatron.core import parallel_state from model import RankingGR, RetrievalGR from modules.metrics import RetrievalTaskMetricWithSampling -from pipeline.train_pipeline import ( - JaggedMegatronPrefetchTrainPipelineSparseDist, - JaggedMegatronTrainNonePipeline, - JaggedMegatronTrainPipelineSparseDist, -) from trainer.utils import cal_flops from utils import TrainerArgs diff --git a/examples/hstu/training/trainer/utils.py b/examples/hstu/training/trainer/utils.py index cecbe1ed1..9b87246f7 100644 --- a/examples/hstu/training/trainer/utils.py +++ b/examples/hstu/training/trainer/utils.py @@ -15,21 +15,20 @@ import sys from typing import Dict, List, Optional, Tuple, Union -import configs import dataset import torch # pylint: disable-unused-import import torch.distributed as dist +from commons.modules.embedding import ShardedEmbeddingConfig +from commons.optimizer import OptimizerParam from configs import ( HSTUConfig, HSTULayerType, HSTUPreprocessingConfig, KernelBackend, - OptimizerParam, PositionEncodingConfig, get_hstu_config, ) from dynamicemb import DynamicEmbTableOptions -from modules.embedding import ShardedEmbeddingConfig from utils import ( BenchmarkDatasetArgs, DatasetArgs, @@ -225,7 +224,7 @@ def get_data_loader( "retrieval", ], f"task type should be ranking or retrieval not {task_type}" if isinstance(dataset_args, BenchmarkDatasetArgs): - from dataset.utils import FeatureConfig + from datasets.utils import FeatureConfig assert ( trainer_args.max_train_iters is not None @@ -302,14 +301,14 @@ def create_embedding_config( hidden_size: int, embedding_args: EmbeddingArgs ) -> ShardedEmbeddingConfig: if isinstance(embedding_args, DynamicEmbeddingArgs): - return configs.ShardedEmbeddingConfig( + return ShardedEmbeddingConfig( feature_names=embedding_args.feature_names, table_name=embedding_args.table_name, vocab_size=embedding_args.item_vocab_size_or_capacity, dim=hidden_size, sharding_type="model_parallel", ) - return configs.ShardedEmbeddingConfig( + return ShardedEmbeddingConfig( feature_names=embedding_args.feature_names, table_name=embedding_args.table_name, vocab_size=embedding_args.item_vocab_size_or_capacity, diff --git a/examples/sid_gr/README.md b/examples/sid_gr/README.md new file mode 100644 index 000000000..ad8097b96 --- /dev/null +++ b/examples/sid_gr/README.md @@ -0,0 +1,93 @@ +# Semantic ID Generative Recommender Example + +## Getting Started + +- **Training**: See the [SID-GR training example](./training/README.md) for detailed instructions + +## Introduction + +**Semantic ID (SID)** based representation addresses the limitations of traditional item representations by tokenizing and quantizing items into a structured semantic space. The key innovation is that items with similar semantic meanings are mapped to nearby positions in the discrete ID space, creating a hierarchical and interpretable item vocabulary. This design offers several advantages: + +- **Semantic coherence**: Items with similar features or user preferences are assigned close semantic identifiers, enabling better generalization +- **Cold-start mitigation**: New items can be mapped to the semantic space based on their content features, reducing dependency on historical interactions +- **Generation efficiency**: With semantic IDs and optimized beam search implementations, the model can retrieve large numbers of candidates at the cost of only a few decoding steps +- **Scalability**: Hierarchical codebook structures (e.g., multi-level quantization) replace high-cardinality flat embedding tables, significantly reducing communication and storage resource requirements while enabling efficient representation of large item catalogs + +This example implements a Semantic ID based Generative Recommender (SID-GR) that combines the strengths of semantic item representations with powerful sequence modeling capabilities. The model backbone uses a standard self-attention decoder architecture, and we have integrated Megatron-Core to leverage its diverse parallelism capabilities. + +## Data Representation + +In this model, each unique PID (Product ID) is mapped to a fixed-length tuple of semantic identifiers. The number of hierarchies (i.e., tuple length) and the cardinality per hierarchy are determined by the user. To obtain semantic meanings, item information is encoded through an LLM into embeddings, followed by a quantization process. Quantization methods include RQ-KMeans, RQ-VAE, etc. See the diagram below: + +

+ SID GR overview +

+ +The mapping process can be handled offline and separately, decoupled from GR training. **This preprocessing step is not covered by this example.** Our work focuses solely on sequential GR training and inference. To ensure compatibility with previously processed sequential datasets, we save the processed PID-to-SID mapping as a PyTorch tensor file. During training, we load both the historical sequential dataset and the mapping tensor(s), performing on-the-fly conversion from PIDs to SIDs without any additional preprocessing of the historical dataset files. + +### PID-to-SID Tokenization + +We use [GRID](https://github.com/snap-research/GRID) to tokenize item product IDs into SID identifiers. After tokenization, the mapping tensor should have shape `[num_hierarchies, num_unique_items]`. To convert PID `p` to SIDs, simply index `mapping[:, p]`. This tensor is loaded by the dataloader. In cases where the number of unique items is extremely large, the mapping tensor can be chunked into multiple tensors. + +### Special Tokens + +In addition to normal SID tokens, a special `` (Beginning of Sequence) token is prepended to each item SID tuple when that item is involved in loss computation. This is performed during the model forward pass. + +**Example:** Given raw history item SIDs consisting of 3 items: `[s1, s2, s3; s4, s5, s6; s7, s8, s9]` + +- **Last item used for loss**: Transformed to `[s1, s2, s3; s4, s5, s6; bos, s7, s8, s9]` + - Using next-token prediction, tokens `bos, s7, s8` predict `s7, s8, s9` for cross-entropy loss computation + +- **Last 2 items used for loss**: Transformed to `[s1, s2, s3; bos, s4, s5, s6; bos, s7, s8, s9]` + +The diagram below illustrates the loss computation logic: + +

+ Loss Computation +

+ +## Embeddings + +Unlike traditional generative recommendation models that assign a unique embedding vector to each item (creating an extremely large and sparse embedding space), SID-based generative recommendation models only require multiple independent small tables. Since the vocabulary size of these tables typically ranges from a few hundred to a few thousand, we adopt a data-parallel strategy to distribute these tables. + +Specifically, we only need to create $\sum_{h \in H} C_{h}$ embedding vectors, where $C_{h}$ is the maximum capacity of hierarchy $h$. Both $H$ (number of hierarchies) and $C_{*}$ (capacities) are determined during the tokenization step. + +## Decoder Stack + +The model uses a standard Transformer decoder architecture, implemented using the Megatron-Core Transformer block for efficient parallel processing. + +## Prediction Head + +The prediction head is typically an MLP layer. Due to the hierarchical structure of SIDs, we support two configurations: + +1. **Shared prediction head**: A single head is shared across all hierarchies + - Training loss labels range from $0$ to $\sum_{h \in H} C_{h} - 1$ + +2. **Per-hierarchy prediction heads**: Each hierarchy has its own dedicated prediction head + - Tokens from the $h$-th hierarchy pass through the $h$-th prediction head + - Label range for each hierarchy: $0$ to $C_{h} - 1$ + +The choice between these two paradigms is controlled by [`NetworkArgs.share_lm_head_across_hierarchies`](./configs/sid_gin_config_args.py). + +## Beam Search Generation + +The SID-GR model performs retrieval through beam search generation. To retrieve $N$ candidates, the process involves $H$ steps of beam search, where the final step's beam width equals $N$. Compared to traditional LLMs, SID-GR has distinct characteristics: + +1. **Predetermined and small number of steps**: + - In LLMs, generation length is not predetermined and continues until certain criteria are met + - In SID-GR, the number of steps always equals the number of hierarchies ($H$), which is typically small (e.g., 3-5) + +2. **Much larger beam width**: + - LLMs use beam search primarily for diversity, typically with beam width < 10 + - Recommender systems require retrieving hundreds or thousands of candidates, necessitating much larger beam widths + +These two characteristics necessitate different performance optimization strategies compared to LLM inference. + + +## References + +- [Tiger: A Database System for Large-Scale Embedding Retrieval](https://arxiv.org/abs/2305.05065) +- [GRID: Generative Retrieval with Identifiers](https://arxiv.org/abs/2507.22224) +- [OneRec: A Generative Recommender](https://arxiv.org/abs/2506.13695) +- [OpenOneRec GitHub Repository](https://github.com/Kuaishou-OneRec/OpenOneRec) +- [T5 Model Documentation (Hugging Face)](https://huggingface.co/docs/transformers/v4.57.3/en/model_doc/t5#t5) \ No newline at end of file diff --git a/examples/sid_gr/beam_search/__init__.py b/examples/sid_gr/beam_search/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/sid_gr/beam_search/beam_search.py b/examples/sid_gr/beam_search/beam_search.py new file mode 100644 index 000000000..72274ff5b --- /dev/null +++ b/examples/sid_gr/beam_search/beam_search.py @@ -0,0 +1,153 @@ +from typing import List, Optional, Tuple, Union + +import torch + + +# TODO, make it graphable +class BeamSearch: + def __init__( + self, + beam_width: Union[int, List[int]], + num_hierarchies: int, + codebook_sizes: List[int], + codebooks: Optional[ + torch.Tensor + ] = None, # to check if the sid is mapped into the codebook + prefix_valid_check: bool = False, # to check if the prefix is valid + record_history: bool = False, + ): + """ + codebooks : [num_items, num_hierarchies] + """ + if isinstance(beam_width, int): + beam_widths = [beam_width] * num_hierarchies + self.beam_widths = beam_widths + self.num_hierarchies = num_hierarchies + self.codebook_sizes = codebook_sizes + assert ( + len(codebook_sizes) == num_hierarchies + ), "codebook_sizes should be the same length as num_hierarchies" + + if prefix_valid_check: + assert ( + codebooks is not None + ), "codebooks should be provided if prefix_valid_check is True" + self.accumulated_log_probs: torch.Tensor = torch.tensor( + [] + ) # to perceive the mppy check + self.generated_sids: torch.Tensor = torch.tensor( + [] + ) # to perceive the mppy check + self.step: int = 0 + + # for debugging purpose + self.record_history: bool = record_history + self.history_topk_sids: List[torch.Tensor] = [] + self.history_accumulate_topk_probs: List[torch.Tensor] = [] + self.history_probs: List[torch.Tensor] = [] + self.reset() + + def propagate( + self, + log_probs: torch.Tensor, # [batch_size, topk_previous_step, codebook_size[step]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + In the beginning of step i, we have already generated sids [batchsize, beam_widths[i-1], i], + We will extend the generated sids [batchsize, beam_widths[i], i + 1] with the log_probs [batchsize, codebook_size[i]] + """ + step = self.step + if step >= self.num_hierarchies: + raise ValueError( + "Reached the last hierarchy, please call reset to start a new generation" + ) + batch_size, codebook_size_this_step = log_probs.shape[0], log_probs.shape[-1] + topk_previous_step = self.generated_sids.shape[1] if step > 0 else 1 + topk_this_step = min( + self.beam_widths[step], topk_previous_step * codebook_size_this_step + ) + + if step == 0: + # initialize the generated sids and accumulated log probs + self.generated_sids = torch.empty( + batch_size, + topk_previous_step, + step, + device=log_probs.device, + dtype=torch.long, + ) + self.accumulated_log_probs = torch.zeros( + batch_size, + topk_previous_step, + device=log_probs.device, + dtype=torch.float, + ) + + log_probs_this_step = log_probs.view( + batch_size, topk_previous_step, codebook_size_this_step + ) + accumulated_log_probs_this_step = ( + self.accumulated_log_probs.view(batch_size, topk_previous_step, 1) + + log_probs_this_step + ) + # [batch_size, topk_previous_step * codebook_size_this_step] + accumulated_log_probs_this_step = accumulated_log_probs_this_step.view( + batch_size, -1 + ) + topk_probs, topk_indices = torch.topk( + accumulated_log_probs_this_step, topk_this_step, dim=-1 + ) + current_step_sids = topk_indices % codebook_size_this_step + last_step_indices = topk_indices // codebook_size_this_step + # [batch_size, topk_this_step, step] + # it's safe to expand to zero when step is 0, + last_step_indices_expanded = last_step_indices.unsqueeze(-1).expand( + -1, -1, step + ) + last_step_sids = torch.gather( + self.generated_sids, dim=1, index=last_step_indices_expanded + ) + generated_sids = torch.cat( + [last_step_sids, current_step_sids.unsqueeze(-1)], dim=-1 + ) + if self.record_history: + self.history_topk_sids.append(generated_sids) + self.history_accumulate_topk_probs.append(torch.exp(topk_probs)) + self.history_probs.append(torch.exp(log_probs_this_step)) + self.generated_sids = generated_sids + self.accumulated_log_probs = topk_probs + self.step += 1 + # [[maybe discard]] + return generated_sids, topk_probs + + def reset(self): + self.generated_sids = None + self.accumulated_log_probs = None + self.step = 0 + self.history_topk_sids = [] + self.history_accumulate_topk_probs = [] + self.history_probs = [] + + def get_sids( + self, + step: Optional[int] = None, # [-1 ~ num_hierarchies) + ) -> torch.Tensor: + """ + return the generated sids at step i if step is valid, otherwise return None. + """ + if step is None: + return self.generated_sids + elif step == -1: + return None + elif step < self.step: + return self.generated_sids[:, :, step] + else: + raise ValueError(f"Step {step} is not valid, current step is {self.step}") + + def generate_valid_mask(self) -> torch.Tensor: + """ + update the valid mask between current step and previous step, + this can be used for transformer attention + """ + + def get_log_probs(self) -> torch.Tensor: + return self.accumulated_log_probs diff --git a/examples/sid_gr/configs/__init__.py b/examples/sid_gr/configs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/sid_gr/configs/args_to_config.py b/examples/sid_gr/configs/args_to_config.py new file mode 100644 index 000000000..9e68be5a3 --- /dev/null +++ b/examples/sid_gr/configs/args_to_config.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from commons.modules.embedding import ShardedEmbeddingConfig + +from .sid_gin_config_args import EmbeddingArgs + + +def create_embedding_config( + hidden_size: int, embedding_args: EmbeddingArgs +) -> ShardedEmbeddingConfig: + return ShardedEmbeddingConfig( + feature_names=embedding_args.feature_names, + table_name=embedding_args.table_name, + vocab_size=embedding_args.item_vocab_size_or_capacity, + dim=hidden_size, + sharding_type=embedding_args.sharding_type, + ) diff --git a/examples/sid_gr/configs/gpt_config.py b/examples/sid_gr/configs/gpt_config.py new file mode 100644 index 000000000..96b34fb8d --- /dev/null +++ b/examples/sid_gr/configs/gpt_config.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from enum import IntFlag, auto + +import torch +from megatron.core.transformer import TransformerConfig + + +class BOSMode(IntFlag): + HISTORY = auto() + CANDIDATE = auto() + ALWAYS = HISTORY | CANDIDATE + + +@dataclass +class GPTConfig(TransformerConfig): + bos_token_mode: BOSMode = BOSMode.CANDIDATE + + def __post_init__(self): + super().__post_init__() + + +def get_gpt_config( + hidden_size: int, + kv_channels: int, + num_attention_heads: int, + num_layers: int, + dtype: torch.dtype, + normalization: str = "RMSNorm", # "LayerNorm" or "RMSNorm" + norm_epsilon: float = 1e-5, + hidden_dropout=0.0, + tensor_model_parallel_size: int = 1, + loss_on_history: bool = False, +) -> GPTConfig: + """ + normalization: { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' + type of normalization applied. + """ + bos_token_mode = BOSMode.CANDIDATE + if loss_on_history: + bos_token_mode |= BOSMode.HISTORY + is_bf16 = dtype == torch.bfloat16 + is_fp16 = dtype == torch.float16 + return GPTConfig( # type: ignore + hidden_size=hidden_size, + num_layers=num_layers, + num_attention_heads=num_attention_heads, + kv_channels=kv_channels, + hidden_dropout=hidden_dropout, + attention_dropout=hidden_dropout, # TODO? + layernorm_epsilon=norm_epsilon, + bf16=is_bf16, + fp16=is_fp16, + tensor_model_parallel_size=tensor_model_parallel_size, + normalization=normalization, + bos_token_mode=bos_token_mode, + ) diff --git a/examples/sid_gr/configs/sid_amazn.gin b/examples/sid_gr/configs/sid_amazn.gin new file mode 100644 index 000000000..8e2c9fb64 --- /dev/null +++ b/examples/sid_gr/configs/sid_amazn.gin @@ -0,0 +1,35 @@ +TrainerArgs.train_batch_size = 128 +TrainerArgs.eval_batch_size = 128 +TrainerArgs.log_interval = 50 +TrainerArgs.eval_interval = 200 +TrainerArgs.profile = False +TrainerArgs.profile_step_start = 50 +TrainerArgs.profile_step_end = 80 +TrainerArgs.max_train_iters = 20000 +TrainerArgs.max_eval_iters = 700 # +TrainerArgs.pipeline_type = "none" + +TrainerArgs.top_k_for_generation = 20 # beam search width +TrainerArgs.eval_metrics = ('NDCG@10', 'Recall@10', 'HitRate@10') # evaluation metrics + +DatasetArgs.dataset_name = 'amzn_beauty' +DatasetArgs.max_history_length = 30 # item-wise +DatasetArgs.max_candidate_length = 0 # loss on history if 0, else only on candidate. +DatasetArgs.dataset_type_str = "disk_sequence_dataset" +DatasetArgs.sequence_features_training_data_path = "./tmp_data/amzn/beauty/training/training_22363.parquet" #"./tmp_data/amzn/beauty/training/22363.parquet" +DatasetArgs.sequence_features_testing_data_path = "./tmp_data/amzn/beauty/evaluation/eval_22363.parquet" # "./tmp_data/amzn/beauty/train_test_split/augmented_all_101605.parquet" # augmented_test_16996 +DatasetArgs.item_to_sid_mapping_path = "./tmp_data/amzn/beauty/item-sid-mapping.pt" +DatasetArgs.shuffle = False +DatasetArgs.num_hierarchies = 4 +DatasetArgs.codebook_sizes = [256, 256, 256, 256] # embedding vocab size for each hierarchy, the last one is used for de-duplication + +NetworkArgs.num_layers = 4 +NetworkArgs.num_attention_heads = 6 +NetworkArgs.hidden_size = 128 # embedding dim +# per head dim +NetworkArgs.kv_channels = 64 +NetworkArgs.share_lm_head_across_hierarchies = True + +OptimizerArgs.optimizer_str = 'adam' +OptimizerArgs.learning_rate = 1e-3 + diff --git a/examples/sid_gr/configs/sid_gin_config_args.py b/examples/sid_gr/configs/sid_gin_config_args.py new file mode 100644 index 000000000..8703396a4 --- /dev/null +++ b/examples/sid_gr/configs/sid_gin_config_args.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional, Tuple + +import gin + + +@gin.configurable +@dataclass +class TrainerArgs: + """Trainer Configuration. + + Training-related parameters and settings. + + Attributes: + train_batch_size (int): **Required**. Batch size per GPU. When TP is enabled, + the theoretical batch size is (train_batch_size × tp_size). + eval_batch_size (int): **Required**. Evaluation batch size. + eval_interval (int): Evaluation interval in iterations. Default: 100. + log_interval (int): Logging interval in iterations. Default: 100. + top_k_for_generation (int): Top K items to generate(retrieve) during evaluation. Default: 10. + eval_metrics (Tuple[str, ...]): Evaluation metrics (e.g., "HR@2", "NDCG@10"). + Default: ("HR@2", "NDCG@10"). + max_train_iters (Optional[int]): Maximum training iterations. Default: None. + max_eval_iters (Optional[int]): Maximum evaluation iterations. Default: None. + seed (int): Random seed. Default: 1234. + profile (bool): Enable profiling. Default: False. + profile_step_start (int): Profiling start step. Default: 100. + profile_step_end (int): Profiling end step. Default: 200. + ckpt_save_interval (int): Checkpoint save interval, -1 means no checkpoint saving. + Default: -1. + ckpt_save_dir (str): Checkpoint save directory. Default: "./checkpoints". + ckpt_load_dir (str): Checkpoint load directory. Default: "". + log_dir (str): Log directory. Default: "./logs". + pipeline_type (str): Pipeline overlap type: 'none' (no overlap), 'native' + (overlap h2d, input dist, fwd+bwd), 'prefetch' (includes prefetch overlap). + Default: "native". + """ + + # below batchsize is batchsize_per_gpu + # when TP is enabled, the theoratical batchsize is (train_batch_size * tp_size) + train_batch_size: int + eval_batch_size: int + + eval_interval: int = 100 + log_interval: int = 100 + + top_k_for_generation: int = 10 + eval_metrics: Tuple[str, ...] = field( + default_factory=lambda: ("HitRate@2", "NDCG@10") + ) + + max_train_iters: Optional[int] = None + max_eval_iters: Optional[int] = None + seed: int = 1234 + + # ==profile args== + profile: bool = False + profile_step_start: int = 100 + profile_step_end: int = 200 + # ==ckpt args== + ckpt_save_interval: int = -1 # -1 means not save ckpt + ckpt_save_dir: str = "./checkpoints" + ckpt_load_dir: str = "" + + # log_dir + log_dir: str = "./logs" + + # overlap pipeline type + # - none -> no overlap + # - native -> overlap [h2d, input dist, fwd+bwd] + # - prefetch -> overlap [h2d, input dist, prefetch, fwd+bwd] + pipeline_type: str = "native" # none, native, prefetch + + def __post_init__(self): + if isinstance(self.max_train_iters, str): + self.max_train_iters = int(self.max_train_iters) + for metric_spec in self.eval_metrics: + metric_name, top_k = metric_spec.split("@") + assert metric_name.lower() in [ + "ndcg", + "recall", + "hitrate", + ], "invalid metric name" + assert ( + int(top_k) <= self.top_k_for_generation + ), "top_k for evaluation should be less than top_k for generation" + + +@gin.configurable +@dataclass +class EmbeddingArgs: + """Embedding Configuration. + + Base embedding layer configuration parameters. + + Attributes: + feature_names (List[str]): **Required**. List of feature names. + table_name (str): **Required**. Embedding table name. + item_vocab_size_or_capacity (int): **Required**. For dynamic embedding: capacity; + for static embedding: vocabulary size. + sharding_type (str): Sharding type, must be "data_parallel" or "model_parallel". + Default: "None". + + Note: + A table could be only one of type `EmbeddingArgs`. + When movielen* or kuairand* datasets are used, `EmbeddingArgs` + are predefined. Setting the proper DatasetArgs.dataset_name in the gin config file will automatically set the proper EmbeddingArgs. + See `examples/sid_gr/data/sid_data_loader.py::get_train_and_test_data_loader()` for more details. + """ + + feature_names: List[str] + table_name: str + item_vocab_size_or_capacity: int + + sharding_type: str = "data_parallel" + + def __post_init__(self): + assert self.sharding_type.lower() in [ + "data_parallel", + "model_parallel", + ] + + +class DatasetType(Enum): + """ + Dataset type: + - InMemoryRandomDataset: in-memory random dataset, used for debugging and testing. + - DiskSequenceDataset: disk-based sequence dataset, used for training and evaluation. + """ + + InMemoryRandomDataset = "in_memory_random_dataset" + DiskSequenceDataset = "disk_sequence_dataset" + + +@gin.configurable +@dataclass +class DatasetArgs: + """Dataset Configuration. + + Dataset-related configuration parameters. + + Attributes: + dataset_name (str): **Required**. Dataset name. + max_history_length (int): **Required**. Maximum history length. + dataset_type (DatasetType): Dataset type. Default: DatasetType.InMemoryRandomDataset. + dataset_type_str (str): Dataset type string. Default: "in_memory_random_dataset". + sequence_features_training_data_path (Optional[str]): Path to training data. Default: None. + sequence_features_testing_data_path (Optional[str]): Path to testing data. Default: None. + shuffle (bool): Whether to shuffle data. Default: False. + item_to_sid_mapping_path (Optional[str]): Path to item to sid mapping. Default: None. + num_hierarchies (int): Number of hierarchies. Default: 4. + codebook_sizes (List[int]): Codebook sizes. Default: [500] * 4. + max_candidate_length (int): Maximum candidate length. Default: 1. + deduplicate_label_across_hierarchy (bool): Whether to deduplicate label across hierarchy. User should not set this explicitly. This is equal to share_lm_head_across_hierarchies. + """ + + dataset_name: str + max_history_length: int + dataset_type: DatasetType = DatasetType.InMemoryRandomDataset + dataset_type_str: str = "in_memory_random_dataset" + sequence_features_training_data_path: Optional[ + str + ] = None # None when dataset_type is InMemoryRandomDataset + sequence_features_testing_data_path: Optional[ + str + ] = None # None when dataset_type is InMemoryRandomDataset + shuffle: bool = False + + # below are used to describe the sid features + item_to_sid_mapping_path: Optional[ + str + ] = None # None when dataset_type is InMemoryRandomDataset or the dataset is already sid features + num_hierarchies: int = 4 + codebook_sizes: List[int] = field(default_factory=lambda: [500] * 4) + max_candidate_length: int = 1 + + # below are used to describe the sid features in the dataset batch + # and the embedding feature names should match the dataset batch feature names + _history_sid_feature_name: str = "hist_sids" + _candidate_sid_feature_name: str = "cand_sids" + deduplicate_label_across_hierarchy: bool = False + + def __post_init__(self): + assert ( + len(self.codebook_sizes) == self.num_hierarchies + ), "codebook_sizes should have the same length as num_hierarchies" + assert self.dataset_type_str.lower() in [ + "in_memory_random_dataset", + "disk_sequence_dataset", + ], "dataset_type_str should be in ['in_memory_random_dataset', 'disk_sequence_dataset']" + if self.dataset_type_str == "in_memory_random_dataset": + self.dataset_type = DatasetType.InMemoryRandomDataset + elif self.dataset_type_str == "disk_sequence_dataset": + self.dataset_type = DatasetType.DiskSequenceDataset + else: + raise ValueError(f"Invalid dataset type: {self.dataset_type_str}") + + +@gin.configurable +@dataclass +class NetworkArgs: + """Network Architecture Configuration. + + Neural network architecture parameters. + + Attributes: + num_layers (int): **Required**. Number of layers. + hidden_size (int): **Required**. Hidden layer size. + num_attention_heads (int): **Required**. Number of attention heads. + kv_channels (int): **Required**. Key-value channels. + hidden_dropout (float): Hidden layer dropout rate. Default: 0.2. + norm_epsilon (float): Normalization epsilon. Default: 1e-5. + is_causal (bool): Use causal attention mask. Default: True. + dtype_str (str): Data type: "bfloat16" or "float16". Default: "bfloat16". + share_lm_head_across_hierarchies (bool): Whether to share language model head + across hierarchies. Default: True. + """ + + num_layers: int + hidden_size: int + num_attention_heads: int + kv_channels: int + + hidden_dropout: float = 0.2 + norm_epsilon: float = 1e-5 + is_causal: bool = True + + dtype_str: str = "bfloat16" + share_lm_head_across_hierarchies: bool = True + + +@gin.configurable +@dataclass +class OptimizerArgs: + """Optimizer Configuration. + + Optimizer-related parameters. + + Attributes: + optimizer_str (str): **Required**. Optimizer name. + learning_rate (float): **Required**. Learning rate. + adam_beta1 (float): Adam optimizer beta1 parameter. Default: 0.9. + adam_beta2 (float): Adam optimizer beta2 parameter. Default: 0.999. + adam_eps (float): Adam optimizer epsilon parameter. Default: 1e-8. + weight_decay (float): Weight decay parameter. Default: 0.01. + """ + + optimizer_str: str + learning_rate: float + adam_beta1: float = 0.9 + adam_beta2: float = 0.999 + adam_eps: float = 1e-8 + weight_decay: float = 0.01 + + +@gin.configurable +@dataclass +class TensorModelParallelArgs: + """Tensor Model Parallelism Configuration. + + Tensor model parallelism settings. + + Attributes: + tensor_model_parallel_size (int): Tensor model parallel size (number of GPUs + for model sharding). Default: 1. + + Note: + The data parallel size is deduced based on the world_size and + tensor_model_parallel_size. + """ + + tensor_model_parallel_size: int = 1 diff --git a/examples/sid_gr/configs/sid_random.gin b/examples/sid_gr/configs/sid_random.gin new file mode 100644 index 000000000..ba51d0c23 --- /dev/null +++ b/examples/sid_gr/configs/sid_random.gin @@ -0,0 +1,31 @@ +TrainerArgs.train_batch_size = 32 +TrainerArgs.eval_batch_size = 32 +TrainerArgs.log_interval = 50 +TrainerArgs.eval_interval = 5000000 +TrainerArgs.profile = True +TrainerArgs.profile_step_start = 50 +TrainerArgs.profile_step_end = 80 +TrainerArgs.max_train_iters = 512 +TrainerArgs.max_eval_iters = 16 +TrainerArgs.pipeline_type = "none" + +DatasetArgs.dataset_name = 'sid_random' +DatasetArgs.max_history_length = 200 +DatasetArgs.dataset_type_str = "in_memory_random_dataset" +DatasetArgs.sequence_features_training_data_path = None +DatasetArgs.sequence_features_testing_data_path = None +DatasetArgs.item_to_sid_mapping_path = None +DatasetArgs.shuffle = False +DatasetArgs.num_hierarchies = 4 +DatasetArgs.codebook_sizes = [500, 500, 500, 500] # embedding vocab size for each hierarchy + +NetworkArgs.num_layers = 1 +NetworkArgs.num_attention_heads = 4 +NetworkArgs.hidden_size = 512 # embedding dim +# per head dim +NetworkArgs.kv_channels = 128 +NetworkArgs.dtype_str = "bfloat16" + +OptimizerArgs.optimizer_str = 'adam' +OptimizerArgs.learning_rate = 1e-3 + diff --git a/examples/sid_gr/datasets/__init__.py b/examples/sid_gr/datasets/__init__.py new file mode 100644 index 000000000..4af690321 --- /dev/null +++ b/examples/sid_gr/datasets/__init__.py @@ -0,0 +1,9 @@ +from .disk_sequence_dataset import DiskSequenceDataset +from .in_memory_random_dataset import InMemoryRandomDataset +from .sid_data_loader import get_train_and_test_data_loader + +__all__ = [ + "InMemoryRandomDataset", + "DiskSequenceDataset", + "get_train_and_test_data_loader", +] diff --git a/examples/sid_gr/datasets/dataset.py b/examples/sid_gr/datasets/dataset.py new file mode 100644 index 000000000..95f2e7ef9 --- /dev/null +++ b/examples/sid_gr/datasets/dataset.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from configs.sid_gin_config_args import DatasetArgs, DatasetType, TrainerArgs + +from .disk_sequence_dataset import DiskSequenceDataset +from .gpt_sid_batch import FeatureConfig +from .in_memory_random_dataset import InMemoryRandomDataset + + +def get_dataset( + dataset_args: DatasetArgs, + trainer_args: TrainerArgs, + is_train_dataset: bool, + rank: int = 0, + world_size: int = 1, + random_seed: int = 1234, +): + max_history_length = dataset_args.max_history_length + max_candidate_length = dataset_args.max_candidate_length + num_hierarchies = dataset_args.num_hierarchies + codebook_sizes = dataset_args.codebook_sizes + assert ( + len(codebook_sizes) == num_hierarchies + ), "codebook_sizes should have the same length as num_hierarchies" + if dataset_args.dataset_type == DatasetType.InMemoryRandomDataset: + # we need to use feature configs to generate random data + feature_configs = [] + raw_hist_sid_names = [f"hist_sid_{i}" for i in range(num_hierarchies)] + raw_cand_sid_names = [f"cand_sid_{i}" for i in range(num_hierarchies)] + # history sid features + feature_configs.append( + FeatureConfig( + feature_names=raw_hist_sid_names, + max_item_ids=[codebook_sizes[i] for i in range(num_hierarchies)], + max_sequence_length=max_history_length, + is_jagged=True, + ) + ) + # candidate sid features + feature_configs.append( + FeatureConfig( + feature_names=raw_cand_sid_names, + max_item_ids=[codebook_sizes[i] for i in range(num_hierarchies)], + max_sequence_length=max_candidate_length, + is_jagged=True, + ) + ) + # no contextual + return InMemoryRandomDataset.get_dataset( + batch_size=trainer_args.train_batch_size + if is_train_dataset + else trainer_args.eval_batch_size, + feature_configs=feature_configs, + raw_hist_sid_names=raw_hist_sid_names, + raw_cand_sid_names=raw_cand_sid_names, + combined_history_feature_name="hist_sids", + combined_candidate_feature_name="cand_sids", + contextual_feature_names=[], + num_generated_batches=1, + num_batches=trainer_args.max_train_iters + if is_train_dataset + else trainer_args.max_eval_iters, + ) + elif dataset_args.dataset_type == DatasetType.DiskSequenceDataset: + dataset_args.dataset_name + return DiskSequenceDataset.get_dataset( + raw_sequence_data_path=dataset_args.sequence_features_training_data_path + if is_train_dataset + else dataset_args.sequence_features_testing_data_path, + item_id_to_sid_mapping_tensor_path=dataset_args.item_to_sid_mapping_path, + batch_size=trainer_args.train_batch_size + if is_train_dataset + else trainer_args.eval_batch_size, + max_history_length=max_history_length, # +1 for the candidate + max_candidate_length=max_candidate_length + if is_train_dataset + else 1, # only 1 candidate item for eval. + raw_sequence_feature_name="sequence_data", # TODO: make it configurable!!! + num_hierarchies=num_hierarchies, + codebook_sizes=codebook_sizes, + output_history_sid_feature_name=dataset_args._history_sid_feature_name, + output_candidate_sid_feature_name=dataset_args._candidate_sid_feature_name, + rank=rank, + world_size=world_size, + shuffle=dataset_args.shuffle, + random_seed=random_seed, + is_train_dataset=is_train_dataset, + deduplicate_data_across_hierarchy=True, # deduplicate data because we are using single embedding tables + deduplicate_label_across_hierarchy=dataset_args.deduplicate_label_across_hierarchy, + ) + else: + raise ValueError(f"Invalid dataset type: {dataset_args.dataset_type}") diff --git a/examples/sid_gr/datasets/disk_sequence_dataset.py b/examples/sid_gr/datasets/disk_sequence_dataset.py new file mode 100644 index 000000000..c8eaadd62 --- /dev/null +++ b/examples/sid_gr/datasets/disk_sequence_dataset.py @@ -0,0 +1,331 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Iterator, List, Optional + +import numpy as np +import pandas as pd +import torch +from torch.utils.data.dataset import IterableDataset +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +from .gpt_sid_batch import GPTSIDBatch + + +class DiskSequenceDataset(IterableDataset[GPTSIDBatch]): + """ + DiskSequenceDataset is an iterable dataset designed for sid-gr + """ + + def __init__( + self, + raw_sequence_data_path: str, + batch_size: int, # local batch size + max_history_length: int, # history seqlen + raw_sequence_feature_name: str, # 'sequence_data' + num_hierarchies: int, + codebook_sizes: List[int], + output_history_sid_feature_name: Optional[str] = None, + output_candidate_sid_feature_name: Optional[str] = None, + max_candidate_length: int = 1, # candidate seqlen + contextual_feature_names: List[str] = [], + item_id_to_sid_mapping_tensor_path: Optional[str] = None, + *, + rank: int, + world_size: int, + shuffle: bool, + random_seed: int, + is_train_dataset: bool, + deduplicate_data_across_hierarchy: bool = True, + deduplicate_label_across_hierarchy: bool = False, + sort_by_user_id: bool = True, # for debugging purpose + ): + # items and timestamps are nested + super().__init__() + if output_history_sid_feature_name is None: + output_history_sid_feature_name = f"hist_sids{num_hierarchies}d" + if output_candidate_sid_feature_name is None: + output_candidate_sid_feature_name = f"cand_sids{num_hierarchies}d" + self._device = torch.cpu.current_device() + raw_sequence_data = pd.read_parquet(raw_sequence_data_path) + if sort_by_user_id: + raw_sequence_data = raw_sequence_data.sort_values(by="user_id") + if "sequence_length" not in raw_sequence_data.columns: + raw_sequence_data["sequence_length"] = raw_sequence_data[ + raw_sequence_feature_name + ].apply(len) + assert ( + max_candidate_length <= 1 + ), "max_candidate_length should be less than or equal to 1 for now" + + # clamp the sequence length to 1 + max_candidate_length (at least 1 history item) + raw_sequence_data = raw_sequence_data[ + raw_sequence_data["sequence_length"] >= 1 + max_candidate_length + ] # at least 2 items in the sequence + + # truncate the sequence to the total sequence length + # note that max_history_length + max_candidate_length is the total sequence length + raw_sequence_data[raw_sequence_feature_name] = raw_sequence_data[ + raw_sequence_feature_name + ].apply( + lambda x: x[: max_history_length + max_candidate_length] + if isinstance(x, (list, np.ndarray)) + else x + ) + raw_sequence_data["sequence_length"] = raw_sequence_data[ + "sequence_length" + ].clip(upper=max_history_length + max_candidate_length) + self._feature_to_max_seqlen = { + output_history_sid_feature_name: max_history_length * num_hierarchies, + output_candidate_sid_feature_name: max_candidate_length * num_hierarchies, + } + try: + self.item_id_to_sid_mapping_tensor = torch.load( + item_id_to_sid_mapping_tensor_path + ) + + if not isinstance(self.item_id_to_sid_mapping_tensor, torch.Tensor): + raise TypeError("item_id_to_sid_mapping_tensor should be a tensor") + assert ( + self.item_id_to_sid_mapping_tensor.dim() == 2 + ), "item_id_to_sid_mapping_tensor should be a 2D tensor" + ( + mappping_num_hierarchies, + num_items, + ) = self.item_id_to_sid_mapping_tensor.shape + assert ( + mappping_num_hierarchies == num_hierarchies + ), "item_id_to_sid_mapping_tensor should have the same number of rows as num_hierarchies" + except: + raise RuntimeError("Failed to load item_id_to_sid_mapping_tensor") + + self._raw_sequence_data = raw_sequence_data + self._num_samples = raw_sequence_data.shape[0] + + self._raw_sequence_feature_name = raw_sequence_feature_name + self._output_history_sid_feature_name = output_history_sid_feature_name + self._output_candidate_sid_feature_name = output_candidate_sid_feature_name + self._num_hierarchies = num_hierarchies + + self._max_candidate_length = max_candidate_length + self._batch_size = batch_size + self._global_batch_size = batch_size * world_size + self._is_train_dataset = is_train_dataset + self._rank = rank + self._world_size = world_size + self._sample_ids = np.arange(self._num_samples) + codebook_offsets = torch.tensor( + np.cumsum([0] + codebook_sizes[:-1]), device=self._device + ) + # dedup data and label offsets + self.data_codebook_offsets = ( + codebook_offsets + if deduplicate_data_across_hierarchy + else torch.zeros(self._num_hierarchies, device=self._device) + ) + self.label_codebook_offsets = ( + codebook_offsets + if deduplicate_label_across_hierarchy + else torch.zeros(self._num_hierarchies, device=self._device) + ) + # TODO: Add shuffle and random seed + + def __iter__(self) -> Iterator[GPTSIDBatch]: + for i in range(len(self)): + local_batch_start = ( + i * self._global_batch_size + self._rank * self._batch_size + ) + local_batch_end = min( + i * self._global_batch_size + (self._rank + 1) * self._batch_size, + len(self._sample_ids), + ) + actual_batch_size = local_batch_end - local_batch_start + sample_ids = self._sample_ids[local_batch_start:local_batch_end] + sequence_data = self._raw_sequence_data.iloc[sample_ids] + # split history and candidate + # [1,2,| 3] => [1,2], [3] + # [1,2,3,4, | 5] => [1,2,3,4], [5] + # [1,2, |3] => [1,2], [3] + # candidate might be empty, so we need to handle it separately + history_item_ids = torch.tensor( + sequence_data[self._raw_sequence_feature_name] + .apply( + lambda x: x[: -self._max_candidate_length] + if self._max_candidate_length > 0 + else x + ) + .explode() + .to_numpy() + .astype(np.int64), + device=self._device, + ) + candidate_item_ids = ( + torch.tensor( + sequence_data[self._raw_sequence_feature_name] + .apply(lambda x: x[-self._max_candidate_length :]) + .explode() + .to_numpy() + .astype(np.int64), + device=self._device, + ) + if self._max_candidate_length > 0 + else None + ) + user_id = torch.tensor( + sequence_data["user_id"].to_numpy().astype(np.int64), + device=self._device, + ) + # add offset to the sids to avoid duplicate sids across hierarchy + # [T, num_hierarchies] + history_sids = torch.index_select( + self.item_id_to_sid_mapping_tensor, dim=1, index=history_item_ids + ).transpose(0, 1).contiguous() + self.data_codebook_offsets.unsqueeze(0) + # labels are the candidate sids but starting from 0. + candidate_sids = ( + ( + torch.index_select( + self.item_id_to_sid_mapping_tensor, + dim=1, + index=candidate_item_ids, + ) + .transpose(0, 1) + .contiguous() + ) + if self._max_candidate_length > 0 + else None + ) + + if self._max_candidate_length > 0: + labels = candidate_sids + self.label_codebook_offsets.unsqueeze(0) + else: + # we need to remove the starting sids for each sequence. + # TODO@junzhang, to optimize the redundant df operations and sid transformations. + label_item_ids = torch.tensor( + sequence_data[self._raw_sequence_feature_name] + .apply(lambda x: x[1:]) + .explode() + .to_numpy() + .astype(np.int64), + device=self._device, + ) + labels = ( + torch.index_select( + self.item_id_to_sid_mapping_tensor, dim=1, index=label_item_ids + ) + .transpose(0, 1) + .contiguous() + ) + self.label_codebook_offsets.unsqueeze(0) + + candidate_sids = ( + candidate_sids + self.data_codebook_offsets.unsqueeze(0) + if self._max_candidate_length > 0 + else None + ) + # 'sequence length' is the total length + history_lengths = ( + torch.tensor( + sequence_data["sequence_length"].to_numpy().astype(np.int64) + - self._max_candidate_length, + device=self._device, + dtype=torch.int64, + ) + * self._num_hierarchies + ) + candidate_lengths = ( + torch.ones(actual_batch_size, device=self._device, dtype=torch.int64) + * self._max_candidate_length + * self._num_hierarchies + ) + + def pad_tensor(padding_length: int, tensor: torch.Tensor) -> torch.Tensor: + if padding_length == 0: + return tensor + return torch.nn.functional.pad( + tensor, (0, padding_length), "constant", 0 + ) + + history_lengths = pad_tensor( + self._batch_size - actual_batch_size, history_lengths + ) + candidate_lengths = pad_tensor( + self._batch_size - actual_batch_size, candidate_lengths + ) + + batch_kwargs = dict( + features=KeyedJaggedTensor.from_lengths_sync( + keys=[ + self._output_history_sid_feature_name, + self._output_candidate_sid_feature_name, + ], + values=torch.cat([history_sids.view(-1), candidate_sids.view(-1)]) + if self._max_candidate_length > 0 + else history_sids.view(-1), + lengths=torch.cat([history_lengths, candidate_lengths]), + ), + batch_size=self._batch_size, + feature_to_max_seqlen=self._feature_to_max_seqlen, + _num_hierarchies=self._num_hierarchies, + history_feature_name=self._output_history_sid_feature_name, + candidate_feature_name=self._output_candidate_sid_feature_name, + labels=labels, # for eval, we need label to calculate metrics. + user_id=user_id, + actual_batch_size=actual_batch_size, + ) + yield GPTSIDBatch(**batch_kwargs) + + def __len__(self) -> int: + return math.ceil(self._num_samples / self._global_batch_size) + + @classmethod + def get_dataset( + cls, + raw_sequence_data_path: str, + item_id_to_sid_mapping_tensor_path: str, + batch_size: int, + max_history_length: int, + max_candidate_length: int, + raw_sequence_feature_name: str, + num_hierarchies: int, + codebook_sizes: List[int], + rank: int, + world_size: int, + shuffle: bool, + random_seed: int, + is_train_dataset: bool, + deduplicate_data_across_hierarchy: bool, + deduplicate_label_across_hierarchy: bool, + output_history_sid_feature_name: str, + output_candidate_sid_feature_name: str, + ) -> "DiskSequenceDataset": + return cls( + raw_sequence_data_path=raw_sequence_data_path, + item_id_to_sid_mapping_tensor_path=item_id_to_sid_mapping_tensor_path, + batch_size=batch_size, + max_history_length=max_history_length, + max_candidate_length=max_candidate_length, + raw_sequence_feature_name=raw_sequence_feature_name, + num_hierarchies=num_hierarchies, + codebook_sizes=codebook_sizes, + output_history_sid_feature_name=output_history_sid_feature_name, + output_candidate_sid_feature_name=output_candidate_sid_feature_name, + rank=rank, + world_size=world_size, + shuffle=shuffle, + random_seed=random_seed, + is_train_dataset=is_train_dataset, + deduplicate_data_across_hierarchy=deduplicate_data_across_hierarchy, + deduplicate_label_across_hierarchy=deduplicate_label_across_hierarchy, + ) diff --git a/examples/sid_gr/datasets/gpt_sid_batch.py b/examples/sid_gr/datasets/gpt_sid_batch.py new file mode 100644 index 000000000..dc7c75e0a --- /dev/null +++ b/examples/sid_gr/datasets/gpt_sid_batch.py @@ -0,0 +1,293 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import torch +from megatron.core.packed_seq_params import PackedSeqParams +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.streamable import Pipelineable + + +def to_packed_seq_params( + cu_seqlens_q, + max_seqlen_q, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_kv: Optional[int] = None, +) -> PackedSeqParams: + cu_seqlens_kv = cu_seqlens_kv or cu_seqlens_q + max_seqlen_kv = max_seqlen_kv or max_seqlen_q + return PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_q.to(torch.int32), + cu_seqlens_kv=cu_seqlens_kv.to(torch.int32), + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) + + +@dataclass +class FeatureConfig: + """ + + A FeatureConfig is a collection of features that share the same seqlen (also the same max_seqlence_length). + For example, an item id feature is mapped to [sid_0, sid_1, sid_2, sid_3] for 4 hierarchies. Those 4 features share one FeatureConfig. + Note that FeatureConfig is only used to generate random data. + + Attributes: + max_item_ids (List[int]): List of maximum item IDs for each feature. + max_history_length (int): The maximum length of sequences in the dataset. + is_jagged (bool): Whether the sequences are jagged (i.e., have varying lengths). + min_item_ids (List[int]): List of minimum item IDs for each feature. + feature_names (List[str]): List of feature names. + """ + + max_item_ids: List[int] # From dataset args + max_sequence_length: int + is_jagged: bool + feature_names: List[str] + + min_item_ids: List[int] = field(default_factory=list) + + def __post_init__(self): + if len(self.min_item_ids) == 0: + self.min_item_ids = [0] * len(self.max_item_ids) + else: + assert len(self.min_item_ids) == len( + self.max_item_ids + ), "min_item_ids should have the same length as max_item_ids" + assert len(self.feature_names) == len( + self.max_item_ids + ), "feature_names should have the same length as max_item_ids" + + +@dataclass +class GPTSIDBatch(Pipelineable): + # TODO: check if candidates are always dense. + features: KeyedJaggedTensor # contextual features, user history features, candidate features + batch_size: int + feature_to_max_seqlen: Dict[str, int] + # currently we do not have contextual features. + contextual_feature_names: List[str] = field(default_factory=lambda: []) + raw_hist_sid_names: List[str] = field( + default_factory=lambda: [] + ) # all those features compose history_feature_name, this is used for random generation + raw_cand_sid_names: List[str] = field( + default_factory=lambda: [] + ) # all those features compose history_feature_name, this is used for random generation + + history_feature_name: str = ( + "history_sequence" # raw sid features are combined into this feature. + ) + candidate_feature_name: str = ( + "candidate_sequence" # raw sid features are combined into this feature. + ) + _num_hierarchies: int = 4 + user_id: Optional[torch.Tensor] = None + labels: Optional[ + torch.Tensor + ] = None # For retrieval, candidates are labels! Inference batch does not have labels. + + actual_batch_size: Optional[int] = None # in case of padding + + def to(self, device: torch.device, non_blocking: bool = True) -> "GPTSIDBatch": # type: ignore + return GPTSIDBatch( + features=self.features.to(device=device, non_blocking=non_blocking), + batch_size=self.batch_size, + actual_batch_size=self.actual_batch_size, + feature_to_max_seqlen=self.feature_to_max_seqlen, + contextual_feature_names=self.contextual_feature_names, + raw_hist_sid_names=self.raw_hist_sid_names, + raw_cand_sid_names=self.raw_cand_sid_names, + labels=self.labels.to(device=device, non_blocking=non_blocking) + if self.labels is not None + else None, + history_feature_name=self.history_feature_name, + candidate_feature_name=self.candidate_feature_name, + _num_hierarchies=self._num_hierarchies, + user_id=self.user_id.to(device=device, non_blocking=non_blocking) + if self.user_id is not None + else None, + ) + + def record_stream(self, stream: torch.cuda.Stream): + self.features.record_stream(stream) + if self.labels is not None: + self.labels.record_stream(stream) + + def retain_candidate_hierarchies( + self, + remained_hierarchies: int, + ) -> "GPTSIDBatch": + candidate_jt = self.features[self.candidate_feature_name] + original_hierarchies = self._num_hierarchies + assert ( + original_hierarchies >= remained_hierarchies + ), "remained_hierarchies should be less than or equal to original_hierarchies" + candidate_lengths = candidate_jt.lengths() - ( + original_hierarchies - remained_hierarchies + ) + candidate_features = ( + candidate_jt.values() + .view(-1, original_hierarchies)[:, :remained_hierarchies] + .reshape(-1) + ) + if self.labels is not None: + labels = self.labels[:, :remained_hierarchies] + history_jt = self.features[self.history_feature_name] + history_lengths = history_jt.lengths() + history_features = history_jt.values() + + kjt = KeyedJaggedTensor.from_lengths_sync( + keys=[ + self.history_feature_name, + self.candidate_feature_name, + ], + values=torch.cat([history_features, candidate_features]), + lengths=torch.cat([history_lengths, candidate_lengths]), + ) + new_batch = GPTSIDBatch( + features=kjt, + batch_size=self.batch_size, + feature_to_max_seqlen=self.feature_to_max_seqlen, + contextual_feature_names=self.contextual_feature_names, + raw_hist_sid_names=self.raw_hist_sid_names, + raw_cand_sid_names=self.raw_cand_sid_names, + _num_hierarchies=remained_hierarchies, + history_feature_name=self.history_feature_name, + candidate_feature_name=self.candidate_feature_name, + user_id=self.user_id, + labels=labels, + ) + return new_batch + + @staticmethod + def random( + batch_size: int, + feature_configs: List[ + FeatureConfig + ], # hist and cand share the same feature config. + raw_hist_sid_names: List[str], + raw_cand_sid_names: List[str], + contextual_feature_names: List[str], + *, + combined_history_feature_name: str = "history_sequence", + combined_candidate_feature_name: str = "candidate_sequence", + device: torch.device, + ) -> "GPTSIDBatch": + feature_name_kvl: Dict[str, Tuple[torch.Tensor, torch.Tensor, int]] = {} + keys = [] + values = [] + lengths = [] + feature_to_max_seqlen = {} + sid_min_ids = [] + for feature_config in feature_configs: + if feature_config.is_jagged: + seqlen = torch.randint( + feature_config.max_sequence_length, (batch_size,), device=device + ) + # the random guarantee the sequence length is at least 1. + # when candidate + seqlen = seqlen.clamp(min=1) + else: + seqlen = torch.full( + (batch_size,), feature_config.max_sequence_length, device=device + ) + total_seqlen = torch.sum(seqlen).item() + feature_names = feature_config.feature_names + max_item_ids = feature_config.max_item_ids + min_item_ids = feature_config.min_item_ids + assert ( + len(feature_names) == len(max_item_ids) == len(min_item_ids) + ), "feature_names, max_item_ids, and min_item_ids should have the same length" + for i in range(len(feature_names)): + key = feature_names[i] + value = torch.randint( + min_item_ids[i], + max_item_ids[i], + (total_seqlen,), + device=device, + ) + feature_name_kvl[key] = ( + value, + seqlen, + feature_config.max_sequence_length, + ) + # we use candidate + if key in raw_cand_sid_names: + sid_min_ids.append(min_item_ids[i]) + + history_sid_kvl = {key: feature_name_kvl.pop(key) for key in raw_hist_sid_names} + candidate_sid_kvl = { + key: feature_name_kvl.pop(key) for key in raw_cand_sid_names + } + feature_name_kvl.update( + { + combined_history_feature_name: ( + torch.stack([v[0] for v in history_sid_kvl.values()], dim=1).view( + -1 + ), + torch.sum( + torch.stack([v[1] for v in history_sid_kvl.values()], dim=1), + dim=1, + ).view(-1), + sum(v[2] for v in history_sid_kvl.values()), + ), + combined_candidate_feature_name: ( + torch.stack([v[0] for v in candidate_sid_kvl.values()], dim=1).view( + -1 + ), + torch.sum( + torch.stack([v[1] for v in candidate_sid_kvl.values()], dim=1), + dim=1, + ).view(-1), + sum(v[2] for v in candidate_sid_kvl.values()), + ), + } + ) + num_hierarchies = len(raw_hist_sid_names) + assert num_hierarchies == len( + raw_cand_sid_names + ), "number of hierarchies should be the same as the number of candidate sid feature names" + keys = list(feature_name_kvl.keys()) + values = [feature_name_kvl[key][0] for key in keys] + lengths = [feature_name_kvl[key][1] for key in keys] + feature_to_max_seqlen = {key: feature_name_kvl[key][2] for key in keys} + features = KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=torch.cat(values).to(device), + lengths=torch.cat(lengths).to(device).long(), + ) + + sid_min_ids = torch.tensor(sid_min_ids, device=device).unsqueeze(0) + #!! labels are the candidate sids but starting from 0. + labels = ( + features[combined_candidate_feature_name].values().view(-1, num_hierarchies) + - sid_min_ids + ) + return GPTSIDBatch( + features=features, + labels=labels, + batch_size=batch_size, + actual_batch_size=batch_size, + feature_to_max_seqlen=feature_to_max_seqlen, + raw_hist_sid_names=raw_hist_sid_names, + raw_cand_sid_names=raw_cand_sid_names, + history_feature_name=combined_history_feature_name, + candidate_feature_name=combined_candidate_feature_name, + contextual_feature_names=contextual_feature_names, + _num_hierarchies=num_hierarchies, + user_id=None, + ) diff --git a/examples/sid_gr/datasets/in_memory_random_dataset.py b/examples/sid_gr/datasets/in_memory_random_dataset.py new file mode 100644 index 000000000..a3747dfff --- /dev/null +++ b/examples/sid_gr/datasets/in_memory_random_dataset.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from typing import Iterator, List, Optional, cast + +import fbgemm_gpu # pylint: disable-unused-import +import torch +from torch.utils.data.dataset import IterableDataset + +from .gpt_sid_batch import FeatureConfig, GPTSIDBatch + + +class InMemoryRandomDataset(IterableDataset[GPTSIDBatch]): + """ + InMemoryRandomDataset is an iterable dataset for generating random batches of data. + + Args: + batch_size (int): The batchsize per rank. + feature_configs (List[FeatureConfig]): A list of configurations for different features. + item_feature_name (str): The name of the item feature. + contextual_feature_names (List[str], optional): A list of names for contextual features. Defaults to an empty list. + action_feature_name (Optional[str], optional): The name of the action feature. Defaults to ``None``. + max_num_candidates (Optional[int], optional): The maximum number of candidates. Defaults to 0. + num_generated_batches (int, optional): The number of batches to generate. Defaults to 1. + num_tasks (int, optional): The number of tasks. Defaults to 0. + num_batches (bool, optional): The total number of batches to iterate over. Defaults to ``None``. + + """ + + def __init__( + self, + batch_size: int, + feature_configs: List[ + FeatureConfig + ], # we need feature config for random generation + raw_hist_sid_names: List[str], + raw_cand_sid_names: List[str], + combined_history_feature_name: str, + combined_candidate_feature_name: str, + contextual_feature_names: List[str] = [], + num_generated_batches=1, + num_batches: Optional[int] = None, + ): + super().__init__() + self.num_batches: int = cast( + int, num_batches if num_batches is not None else sys.maxsize + ) + self._cached_batched: List[GPTSIDBatch] = [] + self._num_generated_batches = num_generated_batches + kwargs = dict( + batch_size=batch_size, + feature_configs=feature_configs, + raw_hist_sid_names=raw_hist_sid_names, + raw_cand_sid_names=raw_cand_sid_names, + combined_history_feature_name=combined_history_feature_name, + combined_candidate_feature_name=combined_candidate_feature_name, + contextual_feature_names=contextual_feature_names, + device=torch.cpu.current_device(), + ) + for _ in range(self._num_generated_batches): + self._cached_batched.append(GPTSIDBatch.random(**kwargs)) + self._iloc = 0 + + def __iter__(self) -> Iterator[GPTSIDBatch]: + """ + Returns an iterator over the cached batches, cycling through them. + + Returns: + Union[RankingBatch, RetrievalBatch]: The next batch in the cycle. + """ + for _ in range(len(self)): + yield self._cached_batched[self._iloc] + self._iloc = (self._iloc + 1) % self._num_generated_batches + + def __len__(self) -> int: + """ + Get the number of batches. + + Returns: + int: The number of batches. + """ + return self.num_batches + + @classmethod + def get_dataset( + cls, + batch_size: int, + feature_configs: List[FeatureConfig], + raw_hist_sid_names: List[str], + raw_cand_sid_names: List[str], + combined_history_feature_name: str, + combined_candidate_feature_name: str, + contextual_feature_names: List[str] = [], + num_generated_batches: int = 1, + num_batches: Optional[int] = None, + ): + return cls( + batch_size=batch_size, + feature_configs=feature_configs, + raw_hist_sid_names=raw_hist_sid_names, + raw_cand_sid_names=raw_cand_sid_names, + combined_history_feature_name=combined_history_feature_name, + combined_candidate_feature_name=combined_candidate_feature_name, + contextual_feature_names=contextual_feature_names, + num_generated_batches=num_generated_batches, + num_batches=num_batches, + ) diff --git a/examples/sid_gr/datasets/preprocessor.py b/examples/sid_gr/datasets/preprocessor.py new file mode 100644 index 000000000..3e47e95d2 --- /dev/null +++ b/examples/sid_gr/datasets/preprocessor.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +import warnings +from abc import ABC +from typing import List + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +log = logging.getLogger("main") +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + + +class DataProcessor(ABC): + def __init__(self): + pass + + @property + def history_feature_name(self): + pass + + @property + def candidate_feature_name(self): + pass + + @property + def contextual_feature_names(self): + pass + + +# class AmazonPreprocessor(DataProcessor): +# def __init__(self): +# super().__init__() +# pass + + +class AmazonBeautyPreprocessor(DataProcessor): + def __init__(self, sequence_features_training_data_path: str): + super().__init__() + self._num_hierarchies = 4 + + @property + def num_hierarchies(self) -> int: + return self._num_hierarchies + + @property + def history_feature_name(self) -> str: + return "history_sequence" + + @property + def candidate_feature_name(self) -> str: + return "candidate_sequence" + + @property + def contextual_feature_names(self) -> List[str]: + return [] + + @property + def sequence_is_sid(self) -> bool: + return False + + @property + def raw_sequence_feature_name(self) -> str: + """ + a raw sequence is split into history and candidate sequences. + """ + return "item_ids" + + +def get_common_preprocessors(sequence_features_training_data_path: str): + return { + "amazon_beauty": AmazonBeautyPreprocessor(sequence_features_training_data_path), + } diff --git a/examples/sid_gr/datasets/sid_data_loader.py b/examples/sid_gr/datasets/sid_data_loader.py new file mode 100644 index 000000000..b949851ae --- /dev/null +++ b/examples/sid_gr/datasets/sid_data_loader.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from configs.sid_gin_config_args import DatasetArgs, TrainerArgs +from torch.distributed import get_rank, get_world_size +from torch.utils.data import DataLoader + +from .dataset import get_dataset + + +def _get_train_and_test_data_loader_from_dataset( + dataset: torch.utils.data.Dataset, + pin_memory: bool = False, +) -> DataLoader: + return DataLoader( + dataset, + batch_size=None, + batch_sampler=None, + pin_memory=pin_memory, + collate_fn=lambda x: x, + ) + + +def get_train_and_test_data_loader( + dataset_args: DatasetArgs, + trainer_args: TrainerArgs, +): + train_dataset = get_dataset( + dataset_args, + trainer_args, + is_train_dataset=True, + rank=get_rank(), + world_size=get_world_size(), + random_seed=trainer_args.seed, + ) + eval_dataset = get_dataset( + dataset_args, + trainer_args, + is_train_dataset=False, + rank=get_rank(), + world_size=get_world_size(), + random_seed=trainer_args.seed, + ) + + train_loader = _get_train_and_test_data_loader_from_dataset(train_dataset) + eval_loader = _get_train_and_test_data_loader_from_dataset(eval_dataset) + + return train_loader, eval_loader diff --git a/examples/sid_gr/figs/sid loss.png b/examples/sid_gr/figs/sid loss.png new file mode 100644 index 000000000..06f54604c Binary files /dev/null and b/examples/sid_gr/figs/sid loss.png differ diff --git a/examples/sid_gr/figs/sid-gr scope.png b/examples/sid_gr/figs/sid-gr scope.png new file mode 100644 index 000000000..f8c2d95bf Binary files /dev/null and b/examples/sid_gr/figs/sid-gr scope.png differ diff --git a/examples/sid_gr/model/__init__.py b/examples/sid_gr/model/__init__.py new file mode 100644 index 000000000..46499783e --- /dev/null +++ b/examples/sid_gr/model/__init__.py @@ -0,0 +1,40 @@ +from typing import List, Optional, Tuple + +from commons.modules.embedding import ShardedEmbeddingConfig +from megatron.core.transformer import TransformerConfig + +from .gpt_model import SIDGRModel +from .mcore_model_specs import get_gpt_decoder_block_spec + +__all__ = ["get_sid_gr_model"] + + +def get_sid_gr_model( + decoder_config: TransformerConfig, + codebook_embedding_config: ShardedEmbeddingConfig, + codebook_sizes: List[int], + num_hierarchies: int, + normalization: Optional[str] = None, + top_k_for_generation: int = 10, + eval_metrics: Tuple[str, ...] = (), + share_lm_head_across_hierarchies: bool = True, +) -> SIDGRModel: + sid_gr_model = SIDGRModel( + decoder_config=decoder_config, + codebook_embedding_config=codebook_embedding_config, + codebook_sizes=codebook_sizes, + num_hierarchies=num_hierarchies, + transformer_decoder_layer_spec=get_gpt_decoder_block_spec( + # padding + arbitrary attention mask + Megatron-Core + decoder_config, + use_transformer_engine=False, + arbitrary_attention_mask=True, + normalization=normalization, + ), + should_add_sep_token=False, + top_k_for_generation=top_k_for_generation, + eval_metrics=eval_metrics, + share_lm_head_across_hierarchies=share_lm_head_across_hierarchies, + ) + + return sid_gr_model diff --git a/examples/sid_gr/model/attention_mask.py b/examples/sid_gr/model/attention_mask.py new file mode 100644 index 000000000..63234f6df --- /dev/null +++ b/examples/sid_gr/model/attention_mask.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +@torch.fx.wrap +def padded_causal_mask_with_optional_bos( + input_offsets: torch.Tensor, + input_max_seqlen: int, + add_bos_to_history: bool = False, + bos_interval: int = 0, +) -> torch.Tensor: + B = input_offsets.size(0) - 1 + S = input_max_seqlen + + # bs, num_head, seq, seq + lower_triangle_mask = torch.tril( + torch.ones( + (B, 1, S, S), + dtype=torch.bool, + device=torch.cuda.current_device(), + ) + ) + if add_bos_to_history: + num_hierarchies_with_bos = bos_interval + 1 + # [[{s0,s1,s2| bos, s3,s4,s5| bos, s6,s7,s8| bos, ..., s_{3N-1}}, {bos, c0,c1,c2}], [{s3,s4,s5| bos, s6,s7,s8| bos, ..., s_{3M-1}}, {bos, c4,c5,c6}]] + assert ( + S + 1 + ) % num_hierarchies_with_bos == 0, ( + "input_max_seqlen + 1 should be divisible by bos_interval + 1" + ) + + # later history tokens can't attend to previous bos tokens + bos_row_ids = torch.arange( + 0, S, device=input_offsets.device, dtype=input_offsets.dtype + ).view(-1, 1) + bos_col_ids = torch.arange( + 0, S, device=input_offsets.device, dtype=input_offsets.dtype + ).view(1, -1) + bos_col_mask = (bos_col_ids + 1) % num_hierarchies_with_bos == 0 + bos_col_mask = bos_col_mask & ( + bos_row_ids >= bos_col_ids + num_hierarchies_with_bos + ) + lower_triangle_mask = lower_triangle_mask & ~bos_col_mask + # bos_row_ids = bos_row_ids[bos_row_ids % (num_hierarchies + 1) == 0] * (num_hierarchies + 1) + else: + # [[{item0, item1, item2, ..., itemN}, {bos}], [{item3, item4, item5, ..., itemM}, {bos}]] + # it's causal + pass + # we set the bos + # broadcast num_head, s_kv + mask = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=torch.ones(size=(input_offsets[-1],)).cuda(), + offsets=[input_offsets], + max_lengths=[input_max_seqlen], + ) + .unsqueeze(1) + .unsqueeze(-1) + ) + jagged_causal_mask = torch.logical_and( + lower_triangle_mask, + mask, + ) + # note that we return the inverse of the mask to match the attention mask format. + return ~jagged_causal_mask + + +@torch.fx.wrap +def padded_history_mask_with_causal_target( + history_seqlen: torch.Tensor, + history_max_seqlen: int, + target_seqlen: int, + history_causal: bool = True, +) -> torch.Tensor: + """ + generate a mask for the history and the causal target. + For history, we pretend it's an encoder, while for target, we pretend it's a decoder. + Args: + history_offsets: [batchsize + 1] + history_max_seqlen: int + target_offsets: [batchsize + 1] + target_max_seqlen: int + Returns: + mask: [batchsize, 1, history_max_seqlen, history_max_seqlen + target_max_seqlen] + + + Example: + + history_max_seqlen = 4 + target_seqlen = 3 + history_offsets = [0, 4, 6] + [[a0,a1,a2,a3,c0,c1,c2], [a4,a5,c3,c4,c5]] + mask: + [ + [ [1,1,1,1,0,0,0], + [1,1,1,1,0,0,0], + [1,1,1,1,0,0,0], + [1,1,1,1,0,0,0]] + [1,1,1,1,1,0,0]] + [1,1,1,1,1,1,0]] + [1,1,1,1,1,1,1]], + + [ [1,1,0,0,0,0,0], + [1,1,0,0,0,0,0], + [1,1,1,0,0,0,0], + [1,1,1,1,0,0,0], + [1,1,1,1,1,0,0], + [0,0,0,0,0,0,0], + [0,0,0,0,0,0,0], + ] + """ + device = history_seqlen.device + # [B,1,1] + valid_lengths = (history_seqlen + target_seqlen).view(-1, 1, 1) + N = history_max_seqlen + target_seqlen + ids = torch.arange(0, N, device=device).view(1, N) + # [1,N,N] + row_ids = ids.unsqueeze(-1).expand(-1, N, N) + col_ids = row_ids.transpose(1, 2) + row_col_dist = row_ids - col_ids + valid_attn_mask = torch.eye(N, device=device, dtype=torch.bool).view(1, N, N) + causal_mask = torch.logical_or(row_col_dist > 0, valid_attn_mask) + history_and_target_mask = torch.logical_and( + row_ids < valid_lengths.view(-1, 1, 1), col_ids < valid_lengths.view(-1, 1, 1) + ) + if not history_causal: + history_mask = torch.logical_and( + row_ids < history_seqlen.view(-1, 1, 1), + col_ids < history_seqlen.view(-1, 1, 1), + ) + history_upper_mask = torch.logical_and(history_mask, row_ids < col_ids) + causal_mask = causal_mask | history_upper_mask + valid_attn_mask = history_and_target_mask & causal_mask + # [B, 1, N, N] for num_head attention + valid_attn_mask = valid_attn_mask.unsqueeze(1) + return ~valid_attn_mask + + +# refer to hstu https://github.com/jiayus-nvidia/FBGEMM/blob/main/fbgemm_gpu/experimental/hstu/img/context_causal_target.png +def padded_target_aware_causal_mask( + history_seqlen: torch.Tensor, + max_history_seqlen: int, + num_target_region: int, + target_max_seqlen_per_region: int, + causal: bool = True, +) -> torch.Tensor: + """ + Used for the beam search where there are multiple beams (targets) for each history. + input sequence is : [history, target_region_0, target_region_1, ... padding_0, padding_1, ...], + where history length is history_seqlen, each target region length is target_max_seqlen_per_region, + and padding length is (max_history_seqlen - history_seqlen). + intra region: causal ; inter region: invisible. + each target needs to attend to the history + + """ + device = history_seqlen.device + target_lengths = target_max_seqlen_per_region * num_target_region + N = max_history_seqlen + target_lengths + valid_lengths = (history_seqlen + target_lengths).view(-1, 1, 1) + + ids = torch.arange(0, N, device=device).view(1, N) + # [B,1,1] + row_ids = ids.unsqueeze(-1).expand(-1, N, N) + col_ids = row_ids.transpose(1, 2) + row_col_dist = row_ids - col_ids + valid_attn_mask = torch.eye(N, device=device, dtype=torch.bool).view(1, N, N) + + valid_region_mask = torch.logical_and( + row_ids < valid_lengths.view(-1, 1, 1), col_ids < valid_lengths.view(-1, 1, 1) + ) + if not causal: + row_col_dist = torch.where(row_col_dist > 0, row_col_dist, -row_col_dist) + valid_attn_mask = torch.logical_or(valid_attn_mask, row_col_dist > 0) + if num_target_region > 0: + target_group_row_ids = ( + torch.clamp(row_ids - valid_lengths + target_lengths, min=-1) + // target_max_seqlen_per_region + ) + target_group_col_ids = target_group_row_ids.transpose(1, 2) + target_dist = target_group_row_ids - target_group_col_ids + + target_group_mask = torch.logical_or( + target_dist == 0, (target_group_row_ids < 0) + (target_group_col_ids < 0) + ) + # preserve the intra-target-group attention and purge the inter-target-group attention + valid_attn_mask = torch.logical_and(valid_attn_mask, target_group_mask) + + # [B, N, N] + valid_attn_mask = valid_attn_mask & valid_region_mask + + # [B, 1, N, N] for num_head attention + valid_attn_mask = valid_attn_mask.unsqueeze(1) + # note that we return the inverse of the mask to match the attention mask format. + return ~valid_attn_mask + + +if __name__ == "__main__": + history_seqlen = torch.tensor([4, 3]).cuda() + max_history_seqlen = 6 + num_target_region = 3 + target_max_seqlen_per_region = 3 + device = torch.device("cuda") + history_causal = False + mask = padded_target_aware_causal_mask( + history_seqlen, + max_history_seqlen, + num_target_region, + target_max_seqlen_per_region, + history_causal, + ) + valid_mask = ~mask + + mask = padded_history_mask_with_causal_target( + history_seqlen, + max_history_seqlen, + target_max_seqlen_per_region, + history_causal, + ) diff --git a/examples/sid_gr/model/gpt_model.py b/examples/sid_gr/model/gpt_model.py new file mode 100644 index 000000000..75a03fcbd --- /dev/null +++ b/examples/sid_gr/model/gpt_model.py @@ -0,0 +1,778 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import torch +from beam_search.beam_search import BeamSearch +from commons.modules.embedding import ShardedEmbedding, ShardedEmbeddingConfig +from commons.ops.cuda_ops.JaggedTensorOpFunction import jagged_2D_tensor_concat +from commons.ops.length_to_offsets import length_to_complete_offsets +from commons.ops.triton_ops.triton_jagged import triton_split_2D_jagged +from configs.gpt_config import BOSMode +from datasets.gpt_sid_batch import GPTSIDBatch, to_packed_seq_params +from megatron.core.enums import ModelType +from megatron.core.extensions.transformer_engine import TEColumnParallelLinear +from megatron.core.models.common.embeddings.relative_pos_embedding import ( + RelativePositionEmbedding, +) +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer import TransformerConfig +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import TransformerBlock +from modules.eval_metrics import SIDRetrievalEvaluator +from modules.gpt_loss_module import GPTSIDLossModule +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor + +from .attention_mask import ( + padded_causal_mask_with_optional_bos, + padded_target_aware_causal_mask, +) + + +def _padding_to_dense_and_transpose( + jagged_input_hidden_states: torch.Tensor, + input_offsets: torch.Tensor, + input_max_seqlen: int, +) -> torch.Tensor: + """ + Padding the jagged input hidden states to dense. + input is Batch major, output is Sequence major. + """ + batch_size = input_offsets.size(0) - 1 + assert ( + jagged_input_hidden_states.dim() == 2 + ), "jagged input hidden states should be 2D" + + padded_hidden_states = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=jagged_input_hidden_states, + offsets=[input_offsets], + max_lengths=[input_max_seqlen], + padding_value=0.0, + ) + .view(batch_size, input_max_seqlen, -1) + .transpose(1, 0) + ) # [S, B, D] + return padded_hidden_states + + +def _transpose_dense_to_jagged( + dense_hidden_states: torch.Tensor, + input_offsets: torch.Tensor, + input_max_seqlen: int, +) -> torch.Tensor: + """ + Convert the dense hidden states to jagged. + input is Sequence major, output is Batch major. + """ + + assert dense_hidden_states.dim() == 3, "dense hidden states should be 3D" + jagged_hidden_states = torch.ops.fbgemm.dense_to_jagged( + dense_hidden_states.transpose(1, 0), # [S, B, D] -> [B, S, D] + [input_offsets], + )[0] + return jagged_hidden_states + + +class SIDGRDecoder(MegatronModule): + """ + Don't support PP currently. Does not inclu de embedding + """ + + def __init__( + self, + decoder_config: TransformerConfig, # decoder config + transformer_decoder_layer_spec: ModuleSpec, + position_embedding_type: Literal[ + "learned_absolute", "rope", "relative" + ] = "learned_absolute", + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + ): + super().__init__(config=decoder_config) + + self.config: TransformerConfig = decoder_config + + self.transformer_decoder_layer_spec: ModuleSpec = transformer_decoder_layer_spec + # TODO, add position encoder + self.model_type = ModelType.encoder_or_decoder + self.position_embedding_type = position_embedding_type + self.decoder_relative_pos_emb = RelativePositionEmbedding( + bidirectional=False, + init_method=self.config.init_method, + num_attention_heads=self.config.num_attention_heads, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + ) + self.decoder = TransformerBlock( + config=self.config, + spec=self.transformer_decoder_layer_spec, + ) + + def forward( + self, + hidden_states, + attention_mask: Optional[ + torch.Tensor + ] = None, # decoder attention mask, always causal + *, + packed_seq_params: Optional[PackedSeqParams] = None, + **kwargs: Any, + ) -> torch.Tensor: + attention_bias = None + # if self.position_embedding_type == 'relative': + # # attention bias is supported by cudnn, but not fa. + # # TODO@junzhang add jagged support once we have attention kernels + # query_seq_length = input_max_seqlen + # key_seq_length = query_seq_length + # attention_bias = self.decoder_relative_pos_emb(query_seq_length, key_seq_length) + output = self.decoder( + hidden_states=hidden_states, # query + attention_mask=attention_mask, # attention mask + packed_seq_params=packed_seq_params, # query and kv seqlens + attention_bias=attention_bias, + **kwargs, + ) + return output + + +class SIDGRModel(MegatronModule): + """ + Don't support PP currently. + """ + + def __init__( + self, + decoder_config: TransformerConfig, # decoder config + codebook_embedding_config: ShardedEmbeddingConfig, # all codebooks share the same embedding + codebook_sizes: List[int], + num_hierarchies: int, + transformer_decoder_layer_spec: ModuleSpec, + position_embedding_type: Literal[ + "learned_absolute", "rope", "relative" + ] = "relative", + user_embedding_config: Optional[ShardedEmbeddingConfig] = None, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + should_add_sep_token: bool = True, + top_k_for_generation: int = 10, # this is used for eval + eval_metrics: Tuple[str, ...] = (), # this is used for eval + share_lm_head_across_hierarchies: bool = True, + ): + super(SIDGRModel, self).__init__(config=decoder_config) + assert ( + position_embedding_type == "relative" + ), "only relative position embedding is supported" + # TODO, use different embedding dim??? + self.embedding_dim = decoder_config.hidden_size + self.codebook_size = codebook_sizes[0] + self.add_bos_to_history_for_training = ( + decoder_config.bos_token_mode & BOSMode.HISTORY + ) != 0 + + assert all( + size == self.codebook_size for size in codebook_sizes + ), "all codebook sizes should be the same" + self._num_hierarchies = num_hierarchies + self._codebooks_collection = ShardedEmbedding( + [codebook_embedding_config] + ) # codebooks can be fused into single table + self._user_embedding_collection = ( + ShardedEmbedding([user_embedding_config]) + if user_embedding_config is not None + else None + ) # user embedding can be fused into single table + self.decoder = SIDGRDecoder( + decoder_config, + transformer_decoder_layer_spec, + position_embedding_type="relative", + ) + self.codebook_sizes = codebook_sizes + assert codebook_embedding_config.vocab_size >= sum( + codebook_sizes + ), "codebook size should be greater than the sum of codebook sizes" + assert ( + len(codebook_sizes) == num_hierarchies + ), "number of codebook sizes should match the number of hierarchies" + # bos_token used to prompt the decoder to generate the first token + # this is duplicated across dp+cp+tp ranks. (DP+CP) be broadcasted, TP same seed. + self.bos_token = torch.nn.Parameter( + torch.randn(1, self.embedding_dim), requires_grad=True + ) + # sep_token used to separate between different items + self.sep_token = ( + torch.nn.Parameter(torch.randn(1, self.embedding_dim), requires_grad=True) + if should_add_sep_token + else None + ) + + self.share_lm_head_across_hierarchies = share_lm_head_across_hierarchies + # output projection for the decoder to project the hidden state to the vocabulary space + # TODO@junzhang, TEColumnParallelLinear does not support gather_output=True + if not share_lm_head_across_hierarchies: + # TODO, combine into single grouped linear layer! + self._decoder_mlp = torch.nn.ModuleList( + [ + TEColumnParallelLinear( + input_size=self.embedding_dim, + output_size=codebook_size, + init_method=self.config.init_method, + config=self.config, + bias=False, + gather_output=False, + skip_bias_add=True, + is_expert=False, + ) + for codebook_size in self.codebook_sizes + ] + ) + else: + self._decoder_mlp = TEColumnParallelLinear( + input_size=self.embedding_dim, + output_size=sum(self.codebook_sizes), + init_method=self.config.init_method, + config=self.config, + bias=False, + gather_output=False, + skip_bias_add=True, + is_expert=False, + ) + + self.loss_module = GPTSIDLossModule( + reduction="none", + ) + + self._training_dtype = ( + torch.float16 + if decoder_config.fp16 + else (torch.bfloat16 if decoder_config.bf16 else torch.float32) + ) + for metric_spec in eval_metrics: + metric_name, top_k = metric_spec.split("@") + assert metric_name.lower() in [ + "ndcg", + "recall", + "hitrate", + ], "invalid metric name" + assert ( + int(top_k) <= top_k_for_generation + ), "top_k for evaluation should be less than top_k for generation" + # below are used for eval + self.top_k_for_generation = top_k_for_generation # beam search width. + + # below comments are reserved for multiple evaluators and debugging purpose + # _evaluators = {} + # for i in range(1, num_hierarchies + 1): + # _evaluators[f"eval_hierarchy_{i}"] = SIDRetrievalEvaluator(eval_metrics, i) + # self.evaluator = MultipleEvaluatorWrapper(_evaluators) + + self.evaluator = SIDRetrievalEvaluator(eval_metrics, num_hierarchies) + self.beam_search = BeamSearch( + beam_width=top_k_for_generation, + num_hierarchies=num_hierarchies, + codebook_sizes=codebook_sizes, + record_history=True, # for debugging purpose + ) + + def bfloat16(self): + """ + Convert the model to use bfloat16 precision. Only affects the decoder & mlp module. + + """ + self.decoder.bfloat16() + self._decoder_mlp.bfloat16() + self.bos_token.data = self.bos_token.data.bfloat16() + return self + + def half(self): + """ + Convert the model to use half precision. Only affects the decoder & mlp module. + + """ + self.decoder.half() + self._decoder_mlp.half() + self.bos_token.data = self.bos_token.data.half() + return self + + # TODO + def _inject_sep_token_between_sids( + self, + id_embeddings: torch.Tensor, + attention_mask: torch.Tensor, + sep_token: torch.Tensor, + num_hierarchies: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return id_embeddings, attention_mask + + def _concat_jagged( + self, + jagged_embeddings: List[torch.Tensor], + jagged_offsets: List[torch.Tensor], + jagged_max_seqlens: List[int], + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + assert ( + len(jagged_embeddings) == len(jagged_offsets) == len(jagged_max_seqlens) + ), "all jagged tensors should have the same length" + if len(jagged_embeddings) == 1: + return jagged_embeddings[0], jagged_offsets[0], jagged_max_seqlens[0] + max_seqlen_concat = sum(jagged_max_seqlens) + + cated_hidden_states, cated_seqlens = jagged_2D_tensor_concat( + jagged_embeddings, + jagged_offsets, + jagged_max_seqlens, + ) + cated_offsets = length_to_complete_offsets(cated_seqlens) + return cated_hidden_states, cated_offsets, max_seqlen_concat + + def _prepare_embeddings( + self, + batch: GPTSIDBatch, + add_bos_to_history: bool = False, + is_generation: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + input has 3 possible cases: + generation: [history,bos] + loss on candidate:[history, bos, candidate] + loss on history and candidate:[history_with_bos_interleaved, bos, candidate]; but note that candidate might be empty. + """ + history_feature_name = batch.history_feature_name + candidate_feature_name = batch.candidate_feature_name + history_features = batch.features[history_feature_name] + max_seqlen_history = batch.feature_to_max_seqlen[history_feature_name] + max_seqlen_candidate = batch.feature_to_max_seqlen[candidate_feature_name] + actual_batch_size = batch.actual_batch_size + history_offsets = history_features.offsets() + if is_generation: + assert ( + not add_bos_to_history + ), "No need to add bos to history for generation" + # 1. embedding lookup + embeddings: Dict[str, JaggedTensor] = self._codebooks_collection(batch.features) + # TODO, remove the assertion + assert all( + feature_name in embeddings.keys() for feature_name in batch.features.keys() + ), "all embedding feature names should be valid" + + history_embeddings = ( + embeddings[history_feature_name].values().to(self._training_dtype) + ) + assert ( + self._num_hierarchies == batch._num_hierarchies + ), "number of hierarchies must match" + + jagged_embeddings = [] + jagged_offsets = [] + jagged_max_seqlens = [] + # 2. if add_bos_to_history, we insert bos token after each item (except the last one) + if add_bos_to_history: + # each item is a tuple of sid, and we need to insert bos token after each item (except the last one). + # [[item0, item1, item2, ...], [item3, item4, item5, ...], ...] -> + # [[{item0| bos, item1| bos, item2|...| bos, itemN}], [{item3| bos, item4| bos, item5|...| bos, itemM}]] + # we use cat to implement this. + history_embeddings = history_embeddings.view( + -1, self._num_hierarchies, self.embedding_dim + ) + bos_token = ( + self.bos_token.view(1, 1, -1) + .expand_as(history_embeddings)[:, :1, :] + .to(self._training_dtype) + ) + history_embeddings = torch.cat([history_embeddings, bos_token], dim=1).view( + -1, self.embedding_dim + ) + history_offsets = history_offsets // self._num_hierarchies + history_offsets + max_seqlen_history = ( + max_seqlen_history + max_seqlen_history // self._num_hierarchies + ) + # remove the last bos token of each sequence + last_bos_offsets = torch.arange( + history_offsets.size(0), + device=history_offsets.device, + dtype=history_offsets.dtype, + ).clamp(max=batch.actual_batch_size) + history_embeddings, _ = triton_split_2D_jagged( + history_embeddings, + max_seq_len=max_seqlen_history, + offsets_a=history_offsets - last_bos_offsets, + offsets_b=last_bos_offsets, + ) + max_seqlen_history -= 1 + history_offsets -= last_bos_offsets + jagged_embeddings.append(history_embeddings) + jagged_offsets.append(history_offsets) + jagged_max_seqlens.append(max_seqlen_history) + if is_generation or max_seqlen_candidate > 0: + # when is_generation, we need to append bos + # when include_candidate, we need to append [bos, candidate] + + # [[item0, item1, item2, ...], [item3, item4, item5, ...], ...] -> + # [[{item0, item1, item2, ..., itemN} | {bos}], [{item3, item4, item5, ..., itemM} | {bos}]] + # the last bos of each sequence is retained for later decoding. + # we use jagged concat + candidate_bos_offsets = torch.arange( + 0, + batch.batch_size + 1, + device=history_offsets.device, + dtype=history_offsets.dtype, + ).clamp(max=actual_batch_size) + bos_token = ( + self.bos_token.repeat(actual_batch_size, 1) + .contiguous() + .to(self._training_dtype) + ) # seqlens * num_hierarchies + jagged_embeddings.append(bos_token) + jagged_offsets.append(candidate_bos_offsets) + jagged_max_seqlens.append(1) + + # For generation, we skip this step + # 3. append candidate. + if not is_generation and max_seqlen_candidate > 0: + # [[{item0| bos, item1| bos, item2|...| bos, itemN}, {bos, candidate0}], [{item3| bos, item4| bos, item5|...| bos, itemM}, {bos, candidate1}]] + candidate_feature_name = batch.candidate_feature_name + jagged_embeddings.append( + embeddings[candidate_feature_name].values().to(self._training_dtype) + ) + jagged_offsets.append(batch.features[candidate_feature_name].offsets()) + jagged_max_seqlens.append( + batch.feature_to_max_seqlen[candidate_feature_name] + ) + ( + input_hidden_states, + input_offsets, + input_max_seqlen, + ) = self._concat_jagged( + jagged_embeddings, + jagged_offsets, + jagged_max_seqlens, + ) + + return input_hidden_states, input_offsets, input_max_seqlen + + def _postprocess_output( + self, + jagged_output_hidden_states: torch.Tensor, + input_max_seqlen: int, + input_offsets: torch.Tensor, + actual_batch_size: int, + history_offsets: torch.Tensor, + output_hierarchies: int, + add_bos_to_history: bool = False, + ) -> torch.Tensor: + """ + input has 2 possible cases: + loss on candidate:[history, bos, candidate] + loss on history and candidate:[history_with_bos_interleaved, bos, candidate] + + but note that candidate might be empty. + """ + # split history, candidate, note that we append a bos token, + # history are dropped. + # [[{item0| bos, item1| bos, item2|...| bos, itemN}, {bos, candidate0}], [{item3| bos, item4| bos, item5|...| bos, itemM}, {bos,candidate1}]] or + # [[{item0,item1,item2... itemN}, {bos, candidate0}], [{item3, item4, item5... itemM}, {bos,candidate1}]] + prefix_offsets_to_remove = ( + torch.arange( + history_offsets.size(0), + device=history_offsets.device, + dtype=history_offsets.dtype, + ).clamp(max=actual_batch_size) + * self._num_hierarchies + if add_bos_to_history + else history_offsets + ) + + # [bos, s0,s1,s2(dropped), bos,s3,s4,s5(dropped), bos,s6,s7,s8(dropped), ... bos,c_n, c_n+1, c_n+2(dropped)] + _, bos_and_candidate_hidden_states = triton_split_2D_jagged( + jagged_output_hidden_states, + max_seq_len=input_max_seqlen, + offsets_a=prefix_offsets_to_remove, + offsets_b=input_offsets - prefix_offsets_to_remove, + ) + candidate_hidden_states = bos_and_candidate_hidden_states.view( + -1, self._num_hierarchies + 1, self.embedding_dim + )[:, :output_hierarchies, :] + return candidate_hidden_states + + def decoder_step( + self, + input_hidden_states: torch.Tensor, + input_offsets: torch.Tensor, + input_max_seqlen: int, + attention_mask: Optional[torch.Tensor] = None, + padding_to_dense: bool = True, + add_bos_to_history: bool = False, + ) -> torch.Tensor: + """ + Input and Output are both jagged. + attention_mask is used only when padding_to_dense is True. + When attention mask is None, we will construct a causal attention mask if padding_to_dense is True. + + We now only support dense input. + """ + if add_bos_to_history: + assert ( + attention_mask is None + ), "attention mask should be None when adding bos to history" + # TODO, remove the padding. + input_offsets[-1].item() + if padding_to_dense: + decoder_input_hidden_states = _padding_to_dense_and_transpose( + input_hidden_states, + input_offsets, + input_max_seqlen, + ) + packed_seq_params = None + if attention_mask is None: + attention_mask = padded_causal_mask_with_optional_bos( + input_offsets, + input_max_seqlen, + add_bos_to_history=add_bos_to_history, + bos_interval=self._num_hierarchies, + ) + else: + # THD still needs batch dimension + # we need to unsqueeze the hidden states to [T, 1, hidden_size] and unsqueeze back after decoder + assert input_hidden_states.dim() == 2, "input_hidden_states should be 2D" + decoder_input_hidden_states = input_hidden_states.unsqueeze(1) + attention_mask = None + packed_seq_params = to_packed_seq_params( + input_offsets, + input_max_seqlen, + ) + decoder_output_hidden_states = self.decoder( + hidden_states=decoder_input_hidden_states, # input_hidden_states, + attention_mask=attention_mask, + packed_seq_params=packed_seq_params, # we now enforce arbitrary attention mask + dense padding + ) + + if padding_to_dense: + output_hidden_states = _transpose_dense_to_jagged( + decoder_output_hidden_states, + input_offsets, + input_max_seqlen, + ) + else: + # remove batch dim if THD + output_hidden_states = decoder_output_hidden_states.squeeze(1) + return output_hidden_states + + def forward( + self, + batch: GPTSIDBatch, + ) -> torch.Tensor: + # 1. prepare embeddings: embedding lookup + history, bos and candidate concat + ( + input_hidden_states, + input_offsets, + input_max_seqlen, + ) = self._prepare_embeddings( + batch, + add_bos_to_history=self.add_bos_to_history_for_training, + is_generation=False, + ) + history_offsets = batch.features[batch.history_feature_name].offsets() + + # 2. decoder step + jagged_output_hidden_states = self.decoder_step( + input_hidden_states, + input_offsets, + input_max_seqlen, + attention_mask=None, + add_bos_to_history=self.add_bos_to_history_for_training, + ) + # 3. postprocess: only keep the candidate hidden states + candidate_hidden_states = self._postprocess_output( + jagged_output_hidden_states, + input_max_seqlen, + input_offsets, + batch.actual_batch_size, + history_offsets, + batch._num_hierarchies, + add_bos_to_history=self.add_bos_to_history_for_training, + ) + losses_per_hierarchy = [] + logits_per_hierarchy = [] + merged_labels = batch.labels.view(-1, batch._num_hierarchies) + # 4. output linear projection & loss + # TODO, merge into single grouped linear layer + for hierarchy_idx in range(batch._num_hierarchies): + # TODO: remove this for debugging purpose + mlp = ( + self._decoder_mlp[hierarchy_idx] + if not self.share_lm_head_across_hierarchies + else self._decoder_mlp + ) + tuple_or_tensor = mlp(candidate_hidden_states[:, hierarchy_idx, :]) + candidate_hierarchy_logits = ( + tuple_or_tensor[0] + if isinstance(tuple_or_tensor, tuple) + else tuple_or_tensor + ) + losses_per_hierarchy.append( + self.loss_module( + candidate_hierarchy_logits.float(), merged_labels[:, hierarchy_idx] + ) + ) # loss needs to be float for + logits_per_hierarchy.append(candidate_hierarchy_logits) + # (T, num_hierarchies) + merged_losses = torch.stack(losses_per_hierarchy, dim=1).view(-1) + merged_logits = torch.stack(logits_per_hierarchy, dim=1).view( + -1, self.codebook_size + ) + return merged_losses, merged_logits + + @torch.no_grad + def generate(self, batch: GPTSIDBatch) -> torch.Tensor: + """ + Generate the output sids for the given batch. The generation will autogressively generate the output sids with a constrained fixed-width beam search strategy. + Args: + batch (GPTSIDBatch): The batch of data. + Returns: + torch.Tensor: The generated sids. + """ + + attention_mask: Optional[torch.Tensor] = None + # 0. prepare history and bos embeddings. Note that we do not append bos to history. + ( + history_embeddings, + input_offsets, + input_max_seqlen, + ) = self._prepare_embeddings( + batch, add_bos_to_history=False, is_generation=True + ) + batch_size = batch.actual_batch_size + input_offsets = input_offsets[: batch_size + 1] + topk_prev_step = 1 + self.beam_search.reset() + for i in range(self._num_hierarchies): + generated_sids = self.beam_search.get_sids() + # 1. prepare embeddings: [concat history, generated sids] + if generated_sids is not None: + # topk might be not always equal to the beam width because we have validation check. + batch_size, topk_prev_step, candidate_length = generated_sids.shape + assert ( + candidate_length == i + ), "current step should match the hierarchy index" + + # we must append hist. This is the defect of torchrec. Considering using torch.nn.Embedding + generated_sids_kjt = KeyedJaggedTensor.from_lengths_sync( + keys=[ + batch.candidate_feature_name, + batch.history_feature_name, + ], + values=generated_sids.view(-1), + lengths=torch.cat( + [ + torch.full( + (batch_size,), + topk_prev_step * candidate_length, + device=generated_sids.device, + dtype=torch.long, + ), + torch.zeros( + (batch_size,), + device=generated_sids.device, + dtype=torch.long, + ), + ] + ), + ) + generated_embeddings = ( + self._codebooks_collection(generated_sids_kjt)[ + batch.candidate_feature_name + ] + .values() + .to(self._training_dtype) + ) + candidate_offsets = generated_sids_kjt[ + batch.candidate_feature_name + ].offsets() + # Jagged concat! + ( + cated_hidden_states, + cated_offsets, + cated_max_seqlen, + ) = self._concat_jagged( + [history_embeddings, generated_embeddings], + [input_offsets, candidate_offsets], + [input_max_seqlen, topk_prev_step * candidate_length], + ) + else: + # when we are at the first step, we do not have any generated sids and only bos token appended to the input. + candidate_length = 0 + cated_hidden_states = history_embeddings + cated_offsets = input_offsets + cated_max_seqlen = input_max_seqlen + + # for first step, a single bos token for each sequence + candidate_offsets = torch.arange( + 0, + batch_size + 1, + device=input_offsets.device, + dtype=input_offsets.dtype, + ) + + # 2. prepare the attention mask + attention_mask = padded_target_aware_causal_mask( + torch.diff(input_offsets), + input_max_seqlen, + 0 if i == 0 else topk_prev_step, + candidate_length, + ) + # 3. we need a decoder step with the concatenated hidden states and offsets. Note that we do not add bos to history for generation. + jagged_output_hidden_states = self.decoder_step( + cated_hidden_states, + cated_offsets, + cated_max_seqlen, + attention_mask=attention_mask, + padding_to_dense=True, + add_bos_to_history=False, + ) + # remove history[batchsize * topk_last_step * max(1,i), embedding_dim] + _, candidate_hidden_states = triton_split_2D_jagged( + jagged_output_hidden_states, + max_seq_len=cated_max_seqlen, + offsets_a=cated_offsets - candidate_offsets, + offsets_b=candidate_offsets, + ) + # 4. calculate the probs for the current step + candidate_hidden_states = candidate_hidden_states.view( + batch_size, topk_prev_step, -1, self.embedding_dim + )[:, :, -1, :] + mlp = ( + self._decoder_mlp[i] + if not self.share_lm_head_across_hierarchies + else self._decoder_mlp + ) + tuple_or_tensor: Union[ + Tuple[torch.Tensor, torch.Tensor], torch.Tensor + ] = mlp(candidate_hidden_states) + # [batch_size, topk_last_step, current_codebook_size] + candidates_logits = ( + tuple_or_tensor[0] + if isinstance(tuple_or_tensor, tuple) + else tuple_or_tensor + ) + probs_this_step: torch.Tensor = torch.nn.functional.log_softmax( + candidates_logits.float(), dim=-1 + ) + # 5. filter the topk candidates, update the generated_sids and log_probs for the next step + self.beam_search.propagate(probs_this_step) + # only for debugging purpose + generated_sids = self.beam_search.get_sids() + log_probs = self.beam_search.get_log_probs() + return generated_sids, log_probs diff --git a/examples/sid_gr/model/mcore_model_specs.py b/examples/sid_gr/model/mcore_model_specs.py new file mode 100644 index 000000000..d5a57392a --- /dev/null +++ b/examples/sid_gr/model/mcore_model_specs.py @@ -0,0 +1,504 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from typing import Optional, Union + +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.multi_latent_attention import ( + MLASelfAttention, + MLASelfAttentionSubmodules, +) +from megatron.core.transformer.multi_token_prediction import ( + MultiTokenPredictionBlockSubmodules, + get_mtp_layer_offset, + get_mtp_layer_spec, + get_mtp_num_layers_to_build, +) +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.torch_norm import L2Norm +from megatron.core.transformer.transformer_block import ( + TransformerBlockSubmodules, + get_num_layers_to_build, +) +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import ( + TransformerLayer, + TransformerLayerSubmodules, + get_transformer_layer_offset, +) +from megatron.core.utils import is_te_min_version + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +from megatron.core.transformer.torch_norm import WrappedTorchNorm + +try: + import apex # pylint: disable=unused-import + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm + + HAVE_APEX = True + LNImpl = FusedLayerNorm +except ImportError: + import warnings + + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + warnings.warn("Apex is not installed. Falling back to Torch Norm") + LNImpl = WrappedTorchNorm + + +def _get_gpt_layer_with_transformer_engine_spec( + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + qk_layernorm: Optional[bool] = False, + multi_latent_attention: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, + qk_l2_norm: Optional[bool] = False, +) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). + + + Args: + num_experts (int, optional): Number of experts. Defaults to None. + moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. + qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. + fp8 (str, optional): Deprecated. For temporary Nemo compatibility. + moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. + Defaults to False. + qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False. + + Returns: + ModuleSpec: Module specification with TE modules + """ + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "_get_gpt_layer_with_transformer_engine_spec" has been deprecated' + " and will be removed soon. Please update your code accordingly." + ) + + mlp = _get_mlp_module_spec( + use_te=True, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) + + if multi_latent_attention: + assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=TEColumnParallelLinear, + linear_q_down_proj=TEColumnParallelLinear, + linear_q_up_proj=( + TELayerNormColumnParallelLinear + if qk_layernorm + else TEColumnParallelLinear + ), + linear_kv_down_proj=TEColumnParallelLinear, + linear_kv_up_proj=( + TELayerNormColumnParallelLinear + if qk_layernorm + else TEColumnParallelLinear + ), + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + else: + # TENorm significantly harms convergence when used + # for QKLayerNorm if TE Version < 1.9; + # we instead use the Apex implementation. + qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=( + L2Norm + if qk_l2_norm + else (qk_norm if qk_layernorm else IdentityOp) + ), + k_layernorm=( + L2Norm + if qk_l2_norm + else (qk_norm if qk_layernorm else IdentityOp) + ), + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def _get_gpt_layer_local_spec( + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + qk_layernorm: Optional[bool] = False, + multi_latent_attention: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, + normalization: Optional[str] = None, + qk_l2_norm: Optional[bool] = False, + arbitrary_attention_mask: Optional[bool] = False, +) -> ModuleSpec: + """Use this spec for an implementation using only modules in Megatron-Core. + + + Args: + num_experts (int, optional): Number of experts. Defaults to None. + moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. + qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. + fp8 (str, optional): Deprecated. For temporary Nemo compatibility. + moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. + Defaults to False. + qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False. + + Returns: + ModuleSpec: Module specification with Megatron-Core modules + """ + if arbitrary_attention_mask: + attention_mask_type = AttnMaskType.arbitrary + else: + attention_mask_type = AttnMaskType.causal + + # Adjust for RMS norm. + if normalization == "RMSNorm": + global LNImpl + LNImpl = WrappedTorchNorm + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "_get_gpt_layer_local_spec" has been deprecated' + " and will be removed soon. Please update your code accordingly." + ) + + mlp = _get_mlp_module_spec( + use_te=False, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) + + if multi_latent_attention: + assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=LNImpl, + self_attention=ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": attention_mask_type}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=ColumnParallelLinear, + linear_q_down_proj=ColumnParallelLinear, + linear_q_up_proj=ColumnParallelLinear, + linear_kv_down_proj=ColumnParallelLinear, + linear_kv_up_proj=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=LNImpl if qk_layernorm else IdentityOp, + kv_layernorm=LNImpl if qk_layernorm else IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=LNImpl, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + else: + # this is the selected path for sid_gr for now. + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=LNImpl, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": attention_mask_type}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=( + L2Norm + if qk_l2_norm + else (LNImpl if qk_layernorm else IdentityOp) + ), + k_layernorm=( + L2Norm + if qk_l2_norm + else (LNImpl if qk_layernorm else IdentityOp) + ), + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=LNImpl, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + sharded_state_dict_keys_map={ + "input_layernorm.": "self_attention.linear_qkv.layer_norm_", + "pre_mlp_layernorm.": "mlp.linear_fc1.layer_norm_", + }, + ), + ) + + +def __get_mlp_module_spec( + use_te: Optional[bool] = True, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, +): + warnings.warn( + """This private function is on a deprecation track. Please switch to `_get_mlp_module_spec` + since it will be removed in a future release.""" + ) + + return _get_mlp_module_spec( + use_te=use_te, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + fp8=fp8, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) + + +def _get_mlp_module_spec( + use_te: Optional[bool] = True, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, +) -> ModuleSpec: + """Helper function to get module spec for MLP/MoE""" + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "__get_mlp_module_spec" has been deprecated' + " and will be removed soon. Please update your code accordingly." + ) + + if num_experts is None: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear + if use_te + else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ), + ) + else: + # Mixture of experts with modules in megatron core. + return get_moe_module_spec( + use_te=use_te, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) + + +def get_gpt_decoder_block_spec( + config: TransformerConfig, + use_transformer_engine: bool, + arbitrary_attention_mask: Optional[bool] = False, + normalization: Optional[str] = None, + qk_l2_norm: Optional[bool] = False, +) -> TransformerBlockSubmodules: + """GPT block spec.""" + if use_transformer_engine: + layer_norm_impl = TENorm + else: + # adjust for rmsnorm + if normalization == "RMSNorm": + layer_norm_impl = WrappedTorchNorm + else: + layer_norm_impl = LNImpl + + if arbitrary_attention_mask: + assert ( + not use_transformer_engine + ), "arbitrary attention mask is only supported with Megatron-Core modules" + + # Layer specs. + dense_layer_spec = ( + _get_gpt_layer_with_transformer_engine_spec( + num_experts=None, + moe_grouped_gemm=False, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + qk_l2_norm=qk_l2_norm, + ) + if use_transformer_engine + else _get_gpt_layer_local_spec( + num_experts=None, + moe_grouped_gemm=False, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + normalization=normalization, + qk_l2_norm=qk_l2_norm, + arbitrary_attention_mask=arbitrary_attention_mask, + ) + ) + moe_layer_spec = ( + _get_gpt_layer_with_transformer_engine_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + qk_l2_norm=qk_l2_norm, + ) + if use_transformer_engine + else _get_gpt_layer_local_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + normalization=normalization, + qk_l2_norm=qk_l2_norm, + arbitrary_attention_mask=arbitrary_attention_mask, + ) + ) + + # Parse config.moe_layer_freq to determine the pattern of expert/dense layers. + # 0 stands for dense layers, 1 stands for expert layers. + # For integer N: Creates a pattern with one expert layer every N layers. + # For string pattern: Evaluates the str directly (e.g. "[1,0,1]" for alternating expert/dense). + if isinstance(config.moe_layer_freq, int): + moe_layer_pattern = [ + 1 if (i % config.moe_layer_freq == 0) else 0 + for i in range(config.num_layers) + ] + elif isinstance(config.moe_layer_freq, list): + moe_layer_pattern = config.moe_layer_freq + assert len(moe_layer_pattern) == config.num_layers, ( + f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, " + f"expected {config.num_layers}, " + f"current moe layer pattern: {config.moe_layer_freq}" + ) + else: + raise ValueError( + f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}" + ) + + # Create the layer specs for the model. + layer_specs = [] + for layer_number in range(config.num_layers): + if moe_layer_pattern[layer_number] == 1: + layer_specs.append(moe_layer_spec) + elif moe_layer_pattern[layer_number] == 0: + layer_specs.append(dense_layer_spec) + else: + raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}") + + # Slice the layer specs to only include the layers that are built in this pipeline stage. + # Note: MCore layer_number starts at 1 + offset = get_transformer_layer_offset(config) + num_layers_to_build = get_num_layers_to_build(config) + layer_specs = layer_specs[offset : offset + num_layers_to_build] + + # Block spec. + block_spec = TransformerBlockSubmodules( + layer_specs=layer_specs, layer_norm=layer_norm_impl + ) + + return block_spec + + +def get_gpt_mtp_block_spec( + config: TransformerConfig, + spec: Union[TransformerBlockSubmodules, ModuleSpec], + use_transformer_engine: bool, +) -> MultiTokenPredictionBlockSubmodules: + """GPT Multi-Token Prediction (MTP) block spec.""" + num_layers_to_build = get_mtp_num_layers_to_build(config) + if num_layers_to_build == 0: + return None + + if isinstance(spec, TransformerBlockSubmodules): + # get the spec for the last layer of decoder block + transformer_layer_spec = spec.layer_specs[-1] + elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer: + transformer_layer_spec = spec + else: + raise ValueError(f"Invalid spec: {spec}") + + mtp_layer_spec = get_mtp_layer_spec( + transformer_layer_spec=transformer_layer_spec, + use_transformer_engine=use_transformer_engine, + ) + mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0 + mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers + + offset = get_mtp_layer_offset(config) + # split the mtp layer specs to only include the layers that are built in this pipeline stage. + mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build] + if len(mtp_layer_specs) > 0: + assert ( + len(mtp_layer_specs) == config.mtp_num_layers + ), f"currently all of the mtp layers must stage in the same pipeline stage." + mtp_block_spec = MultiTokenPredictionBlockSubmodules( + layer_specs=mtp_layer_specs + ) + else: + mtp_block_spec = None + + return mtp_block_spec diff --git a/examples/sid_gr/modules/__init__.py b/examples/sid_gr/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/sid_gr/modules/eval_metrics.py b/examples/sid_gr/modules/eval_metrics.py new file mode 100644 index 000000000..7a5752c27 --- /dev/null +++ b/examples/sid_gr/modules/eval_metrics.py @@ -0,0 +1,288 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Dict, Tuple + +import torch +import torchmetrics + + +class BaseMeanDistributedReductionMetric(torchmetrics.Metric, ABC): + """ + Computes a metric using mean reduction (average across queries) aggregated across distributed workers. + Note that we suppose the parallelism is along the query dimension, that is, each worker only processes a subset of queries. + We allow batch_size per worker to be different. + """ + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self.metric_values: float = 0.0 + self.num_queries: int = 0 + + @abstractmethod + def update( + self, preds: torch.Tensor, target: torch.Tensor, indexes: torch.Tensor, **kwargs + ) -> None: + """ + To be implemented by child classes: update stats from current batch. + """ + + def compute(self) -> torch.Tensor: + """ + Returns: torch.Tensor, the aggregated metric averaged over all queries in all distributed workers. + """ + metric_values_tensor = torch.tensor( + self.metric_values, device=self.device, dtype=torch.float32 + ) + num_queries_tensor = torch.tensor( + self.num_queries, device=self.device, dtype=torch.int64 + ) + if self.num_queries == 0: + return torch.tensor(0.0, device=self.device) + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.all_reduce( + metric_values_tensor, op=torch.distributed.ReduceOp.SUM + ) + torch.distributed.all_reduce( + num_queries_tensor, op=torch.distributed.ReduceOp.SUM + ) + return metric_values_tensor.sum() / num_queries_tensor.sum() + + def reset(self) -> None: + self.metric_values = 0.0 + self.num_queries = 0 + + +class DistributedRetrievalMetric(BaseMeanDistributedReductionMetric, ABC): + """ + Base for distributed retrieval ranking metrics (@K metrics), per query. + """ + + def __init__(self, top_k: int, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.top_k = top_k + + def update( + self, preds: torch.Tensor, target: torch.Tensor, indexes: torch.Tensor, **kwargs + ) -> None: + """ + Args: + preds, target: [batch_size * num_candidates], all flattened. + indexes: [batch_size * num_candidates], each element is the index of the query (0 ... batch_size-1). + """ + # Determine batch and candidate shapes by how index is structured. + # Each group of (num_candidate_per_query) consecutive elements is a query. + # We assume indexes is constructed in row-major order. + # If all queries are contiguous, then the number of zeros in indexes is the batch_size. + num_candidate_per_query = (indexes == 0).sum().item() + batch_size = indexes.numel() // num_candidate_per_query + + preds = preds.view(batch_size, num_candidate_per_query) + target = target.view(batch_size, num_candidate_per_query).int() + + metric_result = self._metric_impl(preds, target) + # sum over batch dimension + self.metric_values += metric_result.sum().item() + self.num_queries += batch_size + + @abstractmethod + def _metric_impl(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Returns tensor of shape [batch_size] + """ + + +class DistributedRetrievalNDCG(DistributedRetrievalMetric): + """ + Normalized Discounted Cumulative Gain@K metric for retrieval (per query/batch). + """ + + def _metric_impl(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + # topk predicted indices by score + topk_indices = torch.topk(preds, self.top_k, dim=1).indices + topk_true = target.gather(1, topk_indices) + + # DCG + denom = torch.log2( + torch.arange(2, self.top_k + 2, device=target.device).float() + ).unsqueeze(0) + dcg = (topk_true / denom).sum(dim=1) + + # Ideal DCG (use ideal ranking) + ideal_indices = torch.topk(target, self.top_k, dim=1).indices + ideal_dcg = (target.gather(1, ideal_indices) / denom).sum(dim=1) + + # Avoid div by zero + ndcg = dcg / torch.where(ideal_dcg == 0, torch.ones_like(ideal_dcg), ideal_dcg) + + return ndcg + + +class DistributedRetrievalRecall(DistributedRetrievalMetric): + """ + Recall@K metric for retrieval (per query/batch). + """ + + def _metric_impl(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ """ + # preds (sorted by score) [1, 2', 3, 4, 5, 6', 7', 8'] + # target (ground truth) [False, True, False, False, False, True, True, True] + + # top2 preds [1, 2'] + # top2 target [False, True] + # recall: |2'| / |1, 2'| = 1 / 2 = 0.5 + # How many relevant items are in the topk predicted items? + topk_indices = torch.topk(preds, self.top_k, dim=1).indices + topk_target = target.gather(1, topk_indices) # [batch, top_k] + num_hit_in_topk = topk_target.sum(dim=1) # [batch], total recalled samples + # for sid, total_relevant <= 1. Because the labels for each query contain single item. + total_relevant = target.sum(dim=1) + # denorm is different from standard torchmetrics. We use the min + denom = total_relevant.minimum( + torch.tensor(self.top_k, device=target.device) + ).clamp(min=1) + recall = num_hit_in_topk / denom + return recall + + +class DistributedRetrievalHitRate(DistributedRetrievalMetric): + """ + Recall@K metric for retrieval (per query/batch). + """ + + def _metric_impl(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + preds: [batchsize, num_candidates] ( for sid, the num_candidates is the beam_width) + target: [batchsize, num_candidates] ( for sid, the target is the ground truth) + """ + + # 1. get the topk result + topk_indices = torch.topk(preds, self.top_k, dim=1).indices + topk_target = target.gather(1, topk_indices) # [batch, top_k] + + # 2. check if topk results hit the ground truth + hit = topk_target.any(dim=1) # [batch] + return hit + + +_metric_str_to_object = { + "ndcg": DistributedRetrievalNDCG, + "recall": DistributedRetrievalRecall, + "hitrate": DistributedRetrievalHitRate, +} + + +class SIDRetrievalEvaluator(torch.nn.Module): + """ + Helper for evaluating retrieval metrics for semantic ID tasks. + """ + + def __init__(self, eval_metrics: Tuple[str, ...], sid_prefix_length: int = -1): + super().__init__() + self.metrics = torch.nn.ModuleDict() + for metric_spec in eval_metrics: + metric_name, top_k = metric_spec.split("@") + metric_class = _metric_str_to_object[metric_name.lower()] + self.metrics[metric_spec] = metric_class( + top_k=int(top_k), sync_on_compute=False, compute_with_cache=False + ) + self.sid_prefix_length = sid_prefix_length + + def state_dict(self): + # Metrics not checkpointed. + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + pass + + def forward( + self, + log_probs: torch.Tensor, + generated_ids: torch.Tensor, + labels: torch.Tensor, + **kwargs, + ): + """ + Args: + log_probs: [batch, num_candidates] + generated_ids: [batch, num_candidates, num_hierarchies] + labels: [batch, num_hierarchies] + """ + batch_size, num_candidates, num_hierarchies = generated_ids.shape + # Reshape for matching + labels = labels.view(batch_size, 1, num_hierarchies) + generated_ids = generated_ids[:, :, : self.sid_prefix_length] + labels = labels[:, :, : self.sid_prefix_length] + preds = log_probs.reshape(-1) + # Match each candidate's IDs to groundtruth: [batch, num_candidates] + matched_id_coord = torch.all(generated_ids == labels, dim=2).nonzero( + as_tuple=True + ) + target = torch.zeros( + batch_size, num_candidates, dtype=torch.bool, device=generated_ids.device + ) + + target[matched_id_coord] = True + target = target.view(-1) + # indexes is not used. Assign a dummy value. + expanded_indexes = ( + torch.arange(batch_size, device=log_probs.device) + .unsqueeze(-1) + .expand(batch_size, num_candidates) + .reshape(-1) + ) + + for metric_obj in self.metrics.values(): + metric_obj.update( + preds, + target.to(preds.device), + indexes=expanded_indexes.to(preds.device), + ) + + def compute(self): + return { + metric_name: metric.compute() + for metric_name, metric in self.metrics.items() + } + + def reset(self): + for metric in self.metrics.values(): + metric.reset() + + +class MultipleEvaluatorWrapper(torch.nn.Module): + """ + Wrapper for multiple evaluators. + """ + + def __init__(self, evaluators: Dict[str, SIDRetrievalEvaluator]): + super().__init__() + self.evaluators = torch.nn.ModuleDict(evaluators) + + def forward( + self, log_probs: torch.Tensor, generated_ids: torch.Tensor, labels: torch.Tensor + ): + for evaluator in self.evaluators.values(): + evaluator(log_probs, generated_ids, labels) + + def compute(self): + ret = {} + for evaluator_name, evaluator in self.evaluators.items(): + ret[evaluator_name] = evaluator.compute() + return ret + + def reset(self): + for evaluator in self.evaluators.values(): + evaluator.reset() diff --git a/examples/sid_gr/modules/gpt_loss_module.py b/examples/sid_gr/modules/gpt_loss_module.py new file mode 100644 index 000000000..a4ab98510 --- /dev/null +++ b/examples/sid_gr/modules/gpt_loss_module.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from commons.utils.nvtx_op import output_nvtx_hook + + +class GPTSIDLossModule(torch.nn.Module): + """ + Multi-task loss module for handling multiple loss functions. A loss head is either a + BCEWithLogitsLoss or CrossEntropyLoss. + """ + + def __init__(self, reduction="none"): + super().__init__() + self._loss_modules = torch.nn.CrossEntropyLoss(reduction=reduction) + + @output_nvtx_hook(nvtx_tag="loss computation") + def forward(self, merged_logits, labels) -> torch.Tensor: + """ + Forward pass of the GPTSIDLossModule. + + Args: + merged_logits (torch.Tensor): (N, num_tasks),The merged logits tensor. Must be 2D tensor of float dtype. + labels (torch.Tensor): (N,), The labels tensor. + + Returns: + torch.Tensor: The computed losses for each task. + """ + assert merged_logits.dim() == 2, "loss module expects 2D logit" + assert merged_logits.dtype == torch.float, "merged_logits dtype should be float" + assert ( + labels.dtype == torch.int32 or labels.dtype == torch.int64 + ), "labels dtype should be integer" + + return self._loss_modules(merged_logits, labels) diff --git a/examples/sid_gr/tests/test_beam_search.py b/examples/sid_gr/tests/test_beam_search.py new file mode 100644 index 000000000..485484388 --- /dev/null +++ b/examples/sid_gr/tests/test_beam_search.py @@ -0,0 +1,69 @@ +import pytest +import torch +from beam_search.beam_search import BeamSearch + + +@pytest.mark.parametrize("batchsize", [10, 20, 50]) +@pytest.mark.parametrize("beam_width", [10, 20, 50]) +@pytest.mark.parametrize("codebook_sizes", [[100, 100, 100]]) +def test_beam_search_sanity_check(batchsize, beam_width, codebook_sizes): + num_hierarchies = len(codebook_sizes) + beam_search = BeamSearch( + beam_width, num_hierarchies, codebook_sizes, record_history=True + ) + topk_prev_step = 1 + for i in range(num_hierarchies): + log_probs = torch.randn( + batchsize, + topk_prev_step, + codebook_sizes[i], + device=torch.cuda.current_device(), + ) + + beam_search.propagate(log_probs) + topk_prev_step = beam_width + # check the childrens' prefix should be from parent + for i in range(1, num_hierarchies): + # shape [batchsize, cur_beam, i + 1] + current_sids = beam_search.history_topk_sids[i] + # shape [batchsize, par_beam, i] + parent_sids = beam_search.history_topk_sids[i - 1] + current_sids_depth = current_sids.shape[-1] + parent_sids_depth = parent_sids.shape[-1] + assert ( + parent_sids_depth + 1 == current_sids_depth + ), "current_sids_depth should be parent_sids_depth + 1" + current_slice = current_sids[:, :, :parent_sids_depth] # [B, cur_beam, K] + parent_slice = parent_sids # [B, par_beam, K] + + # [batchsize, cur_beam, 1, K] == [batchsize, 1, par_beam, K] + is_in = current_slice.unsqueeze(2) == parent_slice.unsqueeze( + 1 + ) # [B, cur_beam, par_beam, K] + in_any_parent = is_in.any(dim=2) # [B, cur_beam, K] + assert torch.all(in_any_parent) + + +@pytest.mark.parametrize("batchsize", [10, 20, 50]) +@pytest.mark.parametrize("codebook_sizes", [[100, 100, 100]]) +def test_beam_search_top1(batchsize, codebook_sizes): + """ + top1 means no beam search, only the top1 candidate is selected. + """ + beam_width = 1 + num_hierarchies = len(codebook_sizes) + beam_search = BeamSearch(beam_width, num_hierarchies, codebook_sizes) + accu_log_probs = torch.zeros(batchsize, device=torch.cuda.current_device()) + sids = torch.empty( + batchsize, 0, device=torch.cuda.current_device(), dtype=torch.long + ) + for i in range(num_hierarchies): + log_probs = torch.randn( + batchsize, 1, codebook_sizes[i], device=torch.cuda.current_device() + ) + beam_search.propagate(log_probs) + accu_log_probs = accu_log_probs.unsqueeze(-1) + log_probs.view(batchsize, -1) + accu_log_probs, current_sids = torch.max(accu_log_probs, dim=-1) + # select the max prob candidate for each batch + sids = torch.cat([sids, current_sids.unsqueeze(-1)], dim=-1) + torch.equal(beam_search.get_sids().view(-1), sids.view(-1)) diff --git a/examples/sid_gr/tests/test_dataset.py b/examples/sid_gr/tests/test_dataset.py new file mode 100644 index 000000000..d8922328d --- /dev/null +++ b/examples/sid_gr/tests/test_dataset.py @@ -0,0 +1,209 @@ +import pytest +import torch +from commons.ops.triton_ops.triton_jagged import triton_split_2D_jagged +from datasets.disk_sequence_dataset import DiskSequenceDataset +from datasets.gpt_sid_batch import FeatureConfig, GPTSIDBatch +from tqdm import tqdm + + +@pytest.mark.parametrize("batch_size", [128, 256, 512]) +def test_batch(batch_size): + feature_configs = [ + FeatureConfig( + feature_names=[ + "hist_sid_0", + "hist_sid_1", + "hist_sid_2", + "hist_sid_3", + "timestamp", + ], + max_item_ids=[128, 128, 128, 128, 100000], + min_item_ids=[0, 0, 0, 0, 0], + max_sequence_length=128, + is_jagged=True, + ), + FeatureConfig( + feature_names=["cand_sid_0", "cand_sid_1", "cand_sid_2", "cand_sid_3"], + max_item_ids=[128, 128, 128, 128], + min_item_ids=[0, 0, 0, 0], + max_sequence_length=128, + is_jagged=True, + ), + FeatureConfig( + feature_names=[ + "contextual_0", + "contextual_1", + ], + max_item_ids=[ + 4, + 100, + ], + min_item_ids=[ + 0, + 0, + ], + max_sequence_length=4, + is_jagged=False, + ), + ] + raw_hist_sid_names = ["hist_sid_0", "hist_sid_1", "hist_sid_2", "hist_sid_3"] + raw_cand_sid_names = ["cand_sid_0", "cand_sid_1", "cand_sid_2", "cand_sid_3"] + contextual_feature_names = ["contextual_0", "contextual_1"] + batch = GPTSIDBatch.random( + batch_size=batch_size, + feature_configs=feature_configs, + raw_hist_sid_names=raw_hist_sid_names, + raw_cand_sid_names=raw_cand_sid_names, + contextual_feature_names=contextual_feature_names, + combined_history_feature_name="hist_sids", + combined_candidate_feature_name="cand_sids", + device=torch.cuda.current_device(), + ) + assert all( + hist_sid_name not in batch.features.keys() + for hist_sid_name in raw_hist_sid_names + ), "history sid feature names should not be in the batch features" + assert all( + cand_sid_name not in batch.features.keys() + for cand_sid_name in raw_cand_sid_names + ), "candidate sid feature names should not be in the batch features" + assert ( + "hist_sids" in batch.features.keys() + ), "history sids feature name should be in the batch features" + assert ( + "cand_sids" in batch.features.keys() + ), "candidate sids feature name should be in the batch features" + assert ( + batch.features["hist_sids"].lengths().numel() == batch_size + ), "history sids feature length should be 128" + assert ( + batch.features["cand_sids"].lengths().numel() == batch_size + ), "candidate sids feature length should be 128" + + +@pytest.mark.parametrize("batch_size", [128, 256, 512]) +@pytest.mark.parametrize("max_history_length", [64, 128, 256]) +@pytest.mark.parametrize("max_candidate_length", [0, 1]) +def test_disk_sequence_dataset( + batch_size, + max_history_length, + max_candidate_length, +): + num_hierarchies = 4 + disk_sequence_dataset = DiskSequenceDataset( + raw_sequence_data_path="./tmp_data/amzn/beauty/training/22363.parquet", + item_id_to_sid_mapping_tensor_path="./tmp_data/amzn/beauty/item-sid-mapping.pt", + batch_size=batch_size, + max_history_length=max_history_length, + max_candidate_length=max_candidate_length, + raw_sequence_feature_name="sequence_data", + num_hierarchies=num_hierarchies, + codebook_sizes=[256, 256, 256, 256], + rank=0, + world_size=1, + shuffle=False, + random_seed=1234, + is_train_dataset=True, + deduplicate_data_across_hierarchy=False, + deduplicate_label_across_hierarchy=False, + ) + num_batches = len(disk_sequence_dataset) + for idx, batch in enumerate( + tqdm( + disk_sequence_dataset, + total=num_batches, + desc="Testing disk sequence dataset", + ) + ): + batch = batch.to(torch.cuda.current_device()) + for key in batch.features.keys(): + assert ( + batch.features[key].lengths().numel() == batch_size + ), f"length of {key} should be {batch_size}" + if idx < len(disk_sequence_dataset) - 1 and max_candidate_length > 0: + assert ( + batch.labels.view(-1, num_hierarchies).shape[0] == batch_size + ), f"labels should be {batch_size}" + if max_candidate_length == 0: + # labels are the history sids + history_sids = ( + batch.features[batch.history_feature_name] + .values() + .view(-1, num_hierarchies) + ) + prefix_to_remove = torch.arange( + batch_size + 1, device=batch.labels.device + ).clamp(max=batch.actual_batch_size) + _, shifted_history_sids = triton_split_2D_jagged( + history_sids, + max_seq_len=max_history_length, + offsets_a=prefix_to_remove, + offsets_b=batch.features[batch.history_feature_name].offsets() + // num_hierarchies + - prefix_to_remove, + ) + labels = batch.labels.view(-1, num_hierarchies) + assert torch.all(labels == shifted_history_sids) + + if batch.actual_batch_size != batch_size: + if max_candidate_length == 1: + assert batch.labels.shape[0] == batch.actual_batch_size + else: + assert ( + batch.labels.shape[0] + == history_sids.shape[0] - batch.actual_batch_size + ) + + +def test_sid_data_loader(): + rank = 0 + world_size = 1 + from configs.sid_gin_config_args import DatasetArgs, TrainerArgs + from datasets.sid_data_loader import get_train_and_test_data_loader + + torch.distributed.init_process_group( + backend="nccl", rank=rank, world_size=world_size + ) + + dataset_args = DatasetArgs( + dataset_name="amzn/beauty", + max_history_length=128, + dataset_type_str="disk_sequence_dataset", + sequence_features_training_data_path="./tmp_data/amzn/beauty/training/22363.parquet", + sequence_features_testing_data_path="./tmp_data/amzn/beauty/testing/22363.parquet", + item_to_sid_mapping_path="./tmp_data/amzn/beauty/item-sid-mapping.pt", + shuffle=False, + num_hierarchies=4, + codebook_sizes=[256, 256, 256, 256], + ) + trainer_args = TrainerArgs( + train_batch_size=128, + eval_batch_size=128, + max_train_iters=1000, + max_eval_iters=100, + seed=1234, + ) + + train_loader, eval_loader = get_train_and_test_data_loader( + dataset_args, trainer_args + ) + for idx, batch in enumerate( + tqdm(train_loader, total=len(train_loader), desc="Testing train loader") + ): + for key in batch.features.keys(): + assert ( + batch.features[key].lengths().numel() == trainer_args.train_batch_size + ), f"length of {key} should be {dataset_args.train_batch_size}" + if idx < len(train_loader) - 1: + assert ( + batch.labels.view(-1, dataset_args.num_hierarchies).shape[0] + == trainer_args.train_batch_size + ), f"labels should be {trainer_args.train_batch_size}" + + for idx, batch in enumerate( + tqdm(eval_loader, total=len(eval_loader), desc="Testing eval loader") + ): + for key in batch.features.keys(): + assert ( + batch.features[key].lengths().numel() == trainer_args.eval_batch_size + ), f"length of {key} should be {trainer_args.eval_batch_size}" diff --git a/examples/sid_gr/tests/test_metric.py b/examples/sid_gr/tests/test_metric.py new file mode 100644 index 000000000..46205eece --- /dev/null +++ b/examples/sid_gr/tests/test_metric.py @@ -0,0 +1,61 @@ +import pytest +import torch +from modules.eval_metrics import DistributedRetrievalHitRate, DistributedRetrievalRecall +from torchmetrics.retrieval import RetrievalHitRate, RetrievalRecall + +ref_metric_dict = { + "hr": RetrievalHitRate, + "recall": RetrievalRecall, +} + +sid_metric_dict = { + "hr": DistributedRetrievalHitRate, + "recall": DistributedRetrievalRecall, +} + + +@pytest.mark.parametrize("eval_metric", ["Recall@10", "HR@10", "HR@20"]) +@pytest.mark.parametrize("batch_size", [512, 1024, 2, 1]) +@pytest.mark.parametrize("num_candidates", [100, 200, 5000]) +def test_sid_retrieval_evaluator( + eval_metric, + batch_size, + num_candidates, +): + device = torch.device("cuda") + + for i in range(10): + log_probs = torch.randn(batch_size, num_candidates, device=device) + target = torch.zeros( + batch_size, num_candidates, dtype=torch.bool, device=device + ) + # set one target to True + col_indices = torch.randint(0, num_candidates, (batch_size,), device=device) + row_indices = torch.arange(batch_size, device=device) + target[row_indices, col_indices] = True + indexes = ( + torch.arange(batch_size, device=log_probs.device) + .unsqueeze(-1) + .expand(-1, num_candidates) + ) + + metric_name, top_k = eval_metric.split("@") + metric_name = metric_name.lower() + + always_hit_batch_id = torch.randint(0, batch_size, (1,), device=device) + log_probs[always_hit_batch_id, col_indices[always_hit_batch_id]] = ( + log_probs[always_hit_batch_id].max() + 0.1 + ) + top_k = int(top_k) + ref_metric = ref_metric_dict[metric_name](top_k=top_k) + # without sync and cache + sid_metric = sid_metric_dict[metric_name]( + top_k=top_k, sync_on_compute=False, compute_with_cache=False + ).cuda() + + sid_metric(log_probs, target, indexes=indexes) + ref_metric(log_probs, target, indexes=indexes) + + sid_results = sid_metric.compute() + ref_results = ref_metric.compute() + assert torch.equal(sid_results, ref_results) diff --git a/examples/sid_gr/tests/test_model_smoke.py b/examples/sid_gr/tests/test_model_smoke.py new file mode 100644 index 000000000..841cd80f2 --- /dev/null +++ b/examples/sid_gr/tests/test_model_smoke.py @@ -0,0 +1,293 @@ +from typing import List + +import commons.utils as init +import pytest +import torch +from commons.checkpoint import get_unwrapped_module +from commons.modules.embedding import ShardedEmbeddingConfig +from commons.ops.length_to_offsets import length_to_complete_offsets +from datasets.gpt_sid_batch import FeatureConfig, GPTSIDBatch +from tests.test_utils import create_sid_gr_model_and_optimizer + + +def generate_batches( + batchsize: int, + num_batches: int, + max_history_length: int, + max_candidate_length: int, + codebook_sizes: List[int], + combined_history_feature_name: str, + combined_candidate_feature_name: str, + contextual_feature_names: List[str], +): + codebook_sizes = torch.tensor(codebook_sizes) + num_hierarchies = len(codebook_sizes) + cum_sum_codebook_size = length_to_complete_offsets(codebook_sizes) + max_item_ids = cum_sum_codebook_size[1:] + min_item_ids = cum_sum_codebook_size[:-1] + raw_hist_sid_names = [f"hist_sid_{i}" for i in range(num_hierarchies)] + raw_cand_sid_names = [f"cand_sid_{i}" for i in range(num_hierarchies)] + raw_feature_configs = [ + FeatureConfig( + feature_names=raw_hist_sid_names, + max_item_ids=max_item_ids, + min_item_ids=min_item_ids, + max_sequence_length=max_history_length, + is_jagged=True, + ), + FeatureConfig( + feature_names=raw_cand_sid_names, + max_item_ids=max_item_ids, + min_item_ids=min_item_ids, + max_sequence_length=max_candidate_length, + is_jagged=False, + ), + ] + return [ + GPTSIDBatch.random( + batch_size=batchsize, + feature_configs=raw_feature_configs, + raw_hist_sid_names=raw_hist_sid_names, + raw_cand_sid_names=raw_cand_sid_names, + combined_history_feature_name=combined_history_feature_name, + combined_candidate_feature_name=combined_candidate_feature_name, + contextual_feature_names=contextual_feature_names, + device=torch.cuda.current_device(), + ) + for _ in range(num_batches) + ] + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [512]) +@pytest.mark.parametrize("num_attention_heads", [4]) +@pytest.mark.parametrize("kv_channels", [128]) +@pytest.mark.parametrize("num_layers", [1]) +@pytest.mark.parametrize("max_history_length", [128]) +@pytest.mark.parametrize("codebook_sizes", [[128, 128, 128, 128], [256, 256, 256]]) +def test_model_smoke( + dtype, + hidden_size, + num_attention_heads, + kv_channels, + num_layers, + max_history_length, + codebook_sizes, +): + # we now only support max_candidate_length = 1 for now + max_candidate_length = 1 + num_hierarchies = len(codebook_sizes) + init.initialize_distributed() + init.initialize_model_parallel(1) # tp1 + init.set_random_seed(1234) + history_sid_feature_name = "hist_sids" + candidate_sid_feature_name = "cand_sids" + codebook_embedding_config = ShardedEmbeddingConfig( + feature_names=[history_sid_feature_name, candidate_sid_feature_name], + table_name="codebook", + vocab_size=sum(codebook_sizes), + dim=hidden_size, + sharding_type="data_parallel", + ) + batchsize = 128 + num_batches = 10 + batches = generate_batches( + batchsize=batchsize, + num_batches=num_batches, + max_history_length=max_history_length, + max_candidate_length=max_candidate_length, + codebook_sizes=codebook_sizes, + combined_history_feature_name=history_sid_feature_name, + combined_candidate_feature_name=candidate_sid_feature_name, + contextual_feature_names=[], + ) + with init.auto_destroy_global_state(): + model, optimizer = create_sid_gr_model_and_optimizer( + dtype=dtype, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + kv_channels=kv_channels, + num_layers=num_layers, + num_hierarchies=num_hierarchies, + codebook_embedding_config=codebook_embedding_config, + codebook_sizes=codebook_sizes, + ) + optimizer.reload_model_params() + + for batch in batches: + batch.to(torch.cuda.current_device()) + output = model(batch) + # each sequence corresponds to one loss. + loss, logits = output + assert ( + loss.shape[0] + == batch.features[batch.candidate_feature_name].offsets()[-1] + ) + assert output is not None + loss.sum().backward() + + +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("hidden_size", [512]) +@pytest.mark.parametrize("num_attention_heads", [4]) +@pytest.mark.parametrize("kv_channels", [128]) +@pytest.mark.parametrize("num_layers", [1]) +@pytest.mark.parametrize("max_history_length", [128]) +@pytest.mark.parametrize("codebook_sizes", [[128, 128, 128, 128], [256, 256, 256]]) +def test_model_decoder_step( + dtype, + hidden_size, + num_attention_heads, + kv_channels, + num_layers, + max_history_length, + codebook_sizes, +): + num_hierarchies = len(codebook_sizes) + init.initialize_distributed() + init.initialize_model_parallel(1) + init.set_random_seed(1234) + history_sid_feature_name = "hist_sids" + candidate_sid_feature_name = "cand_sids" + codebook_embedding_config = ShardedEmbeddingConfig( + feature_names=[history_sid_feature_name, candidate_sid_feature_name], + table_name="codebook", + vocab_size=sum(codebook_sizes), + dim=hidden_size, + sharding_type="data_parallel", + ) + batch_size = 1 + with init.auto_destroy_global_state(): + model, optimizer = create_sid_gr_model_and_optimizer( + dtype=dtype, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + kv_channels=kv_channels, + num_layers=num_layers, + num_hierarchies=num_hierarchies, + codebook_embedding_config=codebook_embedding_config, + codebook_sizes=codebook_sizes, + ) + optimizer.reload_model_params() + model = get_unwrapped_module(model) + # inference mode + model.eval() + for i in range(10): + history_hiddens = torch.randn( + batch_size, + max_history_length, + hidden_size, + device=torch.cuda.current_device(), + dtype=dtype, + ) + history_offsets = ( + torch.arange( + 0, + batch_size + 1, + device=torch.cuda.current_device(), + dtype=torch.long, + ) + * max_history_length + ) + history_max_seqlen = max_history_length + + # each history corresponds to one candidate + candidates_hiddens = torch.randn( + batch_size, + num_hierarchies, + hidden_size, + device=torch.cuda.current_device(), + dtype=dtype, + ) + candidate_offsets = ( + torch.arange( + 0, + batch_size + 1, + device=torch.cuda.current_device(), + dtype=torch.long, + ) + * num_hierarchies + ) + candidate_max_seqlen = num_hierarchies + + input_hidden_states = torch.cat( + [history_hiddens, candidates_hiddens], dim=1 + ) + input_offsets = history_offsets + candidate_offsets + input_max_seqlen = history_max_seqlen + candidate_max_seqlen + + # decoding in one shot + output = model.decoder_step( + input_hidden_states.view(-1, hidden_size), + input_offsets, + input_max_seqlen, + ).view(batch_size, input_max_seqlen, -1) + candidates_logits = output.view(batch_size, input_max_seqlen, -1)[ + :, history_max_seqlen:, : + ] + ref_probs = [] + for h in range(num_hierarchies): + mlp = model._decoder_mlp[h] + tuple_or_tensor = mlp(candidates_logits[:, h, :]) + logits = ( + tuple_or_tensor[0] + if isinstance(tuple_or_tensor, tuple) + else tuple_or_tensor + ) + ref_probs.append(torch.nn.functional.softmax(logits.float(), dim=-1)) + + # decoding one by one + for h in range(1, num_hierarchies + 1): + # h = num_hierarchies + # h = num_hierarchies + prefix_candidates_hiddens = candidates_hiddens[:, :h, :] + prefix_candidate_offsets = ( + torch.arange( + 0, + batch_size + 1, + device=torch.cuda.current_device(), + dtype=torch.long, + ) + * h + ) + prefix_candidate_max_seqlen = h + prefix_input_hidden_states = torch.cat( + [history_hiddens, prefix_candidates_hiddens], dim=1 + ) + prefix_input_offsets = history_offsets + prefix_candidate_offsets + prefix_input_max_seqlen = ( + history_max_seqlen + prefix_candidate_max_seqlen + ) + + prefix_output = model.decoder_step( + prefix_input_hidden_states.view(-1, hidden_size), + prefix_input_offsets, + prefix_input_max_seqlen, + ).view(batch_size, prefix_input_max_seqlen, -1) + prefix_candidates_logits = prefix_output.view( + batch_size, prefix_input_max_seqlen, -1 + )[:, history_max_seqlen:, :] + for hh in range(0, h): + ref_prob = ref_probs[hh] + tuple_or_tensor = model._decoder_mlp[hh]( + prefix_candidates_logits[:, hh, :] + ) + prob = ( + tuple_or_tensor[0] + if isinstance(tuple_or_tensor, tuple) + else tuple_or_tensor + ) + prob = prob.float().softmax(dim=-1) + this_sorted_prob, this_sorted_indices = torch.sort( + prob, dim=-1, descending=True + ) + sorted_ref_prob, sorted_ref_indices = torch.sort( + ref_prob, dim=-1, descending=True + ) + # top 10? + the_same_order = ( + sorted_ref_indices[0:10] == this_sorted_indices[0:10] + ).all() + # import pdb; pdb.set_trace() + assert the_same_order + # assert torch.allclose(ref_prob, prob, atol=1e-4) diff --git a/examples/sid_gr/tests/test_utils.py b/examples/sid_gr/tests/test_utils.py new file mode 100644 index 000000000..cac3c2394 --- /dev/null +++ b/examples/sid_gr/tests/test_utils.py @@ -0,0 +1,75 @@ +from typing import List + +import torch +from commons.distributed.sharding import make_optimizer_and_shard +from commons.modules.embedding import ShardedEmbeddingConfig +from commons.optimizer import OptimizerParam +from configs.gpt_config import get_gpt_config +from model.gpt_model import SIDGRModel +from model.mcore_model_specs import get_gpt_decoder_block_spec + + +def create_sid_gr_model_and_optimizer( + dtype: torch.dtype, + hidden_size: int, + num_attention_heads: int, + kv_channels: int, + num_layers: int, + num_hierarchies: int, + codebook_embedding_config: ShardedEmbeddingConfig, + codebook_sizes: List[int], + should_add_sep_token: bool = False, + optimizer_type_str: str = "adam", + pipeline_type: str = "none", + device: torch.device = None, +): + decoder_config = get_gpt_config( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + kv_channels=kv_channels, + num_layers=num_layers, + dtype=dtype, + normalization="LayerNorm", + norm_epsilon=1e-5, + hidden_dropout=0.0, + tensor_model_parallel_size=1, + loss_on_history=False, + ) + # thd + causal + TE + transformer_decoder_layer_spec = get_gpt_decoder_block_spec( + decoder_config, + use_transformer_engine=False, + arbitrary_attention_mask=True, + ) + + sid_gr_model = SIDGRModel( + decoder_config=decoder_config, + codebook_embedding_config=codebook_embedding_config, + codebook_sizes=codebook_sizes, + num_hierarchies=num_hierarchies, + transformer_decoder_layer_spec=transformer_decoder_layer_spec, + should_add_sep_token=should_add_sep_token, + top_k_for_generation=10, + eval_metrics=("HitRate@2", "NDCG@10"), + share_lm_head_across_hierarchies=False, + ) + + optimizer_param = OptimizerParam( + optimizer_str=optimizer_type_str, + learning_rate=1e-3 if optimizer_type_str == "adam" else 1e-1, + adam_beta1=0.5, # larger beta1 for better debugging! + adam_beta2=0.999, + adam_eps=1e-8, + weight_decay=0.0, # decay is off for better debugging + ) + + model_train, dense_optimizer = make_optimizer_and_shard( + sid_gr_model, + config=decoder_config, + sparse_optimizer_param=optimizer_param, + dense_optimizer_param=optimizer_param, + pipeline_type=pipeline_type, + device=device, + ) + + return model_train, dense_optimizer diff --git a/examples/sid_gr/training/README.md b/examples/sid_gr/training/README.md new file mode 100644 index 000000000..0261f1e1e --- /dev/null +++ b/examples/sid_gr/training/README.md @@ -0,0 +1,86 @@ +# SID-GR Training Example + +This example implements a retrieval model with a standard Transformer decoder backbone. We use [Gin config](https://github.com/google/gin-config) to specify model hyperparameters and training configurations (e.g., dataset file paths, training steps). Currently, this implementation has been validated on the Amazon Beauty dataset. + +For detailed information about the Gin config interface and available parameters, please refer to the [inline documentation](../configs/sid_gin_config_args.py). + +## Important: `max_candidate_length` Configuration + +The `DatasetArgs.max_candidate_length` parameter controls which items in the sequence are used for loss calculation: + +- `max_candidate_length=1`: Only the **last item** in the sequence is used to calculate loss +- `max_candidate_length=0`: **All items except the first one** in the sequence are used to calculate loss + +## Dataset Preprocessing + +This example requires two types of dataset files: + +1. **PID-to-SID mapping tensor**: A PyTorch tensor file (loadable via `torch.load()`) +2. **Historical interaction sequences**: Parquet format files containing user-item interaction histories + +### PID-to-SID Mapping + +The PID-to-SID tokenization process (i.e., converting product IDs to semantic IDs) is **not included** in this example. Users must tokenize items separately before training. We recommend using [GRID](https://github.com/snap-research/GRID) for this purpose. + +**Requirements:** +- The mapping tensor must have shape: `[num_hierarchies, num_items]` +- The tensor should be compatible with `torch.load()` + +### Historical Sequence File + +Similar to other sequential models (e.g., HSTU), each user has a historical interaction sequence with items. This example uses the **Parquet format**, which offers superior file compression and I/O performance. + +**File structure:** +- Each row represents a user's interaction history +- The nested column contains the sequential history (variable length supported) +- Optional columns: user ID, sequence length +- A single user may span multiple rows with varying sequence lengths + +### Dataset Statistics + +| Dataset | # Users | Max Seq Len | Min Seq Len | Mean Seq Len | Median Seq Len | # Items | +|---------------|---------|-------------|-------------|--------------|----------------|---------| +| Amazon Beauty | 22,363 | 202 | 3 | 7 | 4 | 12,101 | + +## Jagged Tensor Support + +This implementation assumes variable-length (jagged) input sequences. We leverage [TorchRec Jagged Tensor](https://docs.pytorch.org/tutorials/intermediate/torchrec_intro_tutorial.html#torchrec-input-output-data-types) utilities to efficiently handle jagged tensor operations. + +**Note:** Jagged tensors are also referred to as the `THD` (Total, Head, Dim) layout in [Megatron-Core / Transformer-Engine](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/c/fused_attn.html#_CPPv4N15NVTE_QKV_Format8NVTE_THDE). + +## Getting Started + +### 1. Download Dataset + +Download the demo dataset from [Hugging Face](https://huggingface.co/datasets/DuDoddle/sid-amazon-beauty/) and ensure the data paths are correctly configured in the [Gin config file](../configs/sid_amazn.gin). + +### 2. Run Training + +The training entry point is [pretrain_sid_gr.py](./pretrain_sid_gr.py). + +**Command to train on Amazon Beauty dataset:** + +```bash +# Navigate to the sid_gr directory +cd /examples/sid_gr + +# Run training with 1 GPU +PYTHONPATH=${PYTHONPATH}:$(realpath ../) torchrun \ + --nproc_per_node 1 \ + --master_addr localhost \ + --master_port 6000 \ + ./training/pretrain_sid_gr.py \ + --gin-config-file ./configs/sid_amazn.gin +``` + +**Note:** Ensure your current working directory is `examples/sid_gr` before running the command. + +## Known Limitations + +⚠️ **This implementation is under active development.** The current version has not been fully optimized for performance. Known limitations include: + +- **Attention mechanism**: Currently using padded local SDPA (Scaled Dot-Product Attention) implementation in Megatron-Core with explicit attention masks +- **Beam search**: The beam search used during evaluation does not yet support KV cache optimization +- **Performance**: The model performance has not reached optimal levels + +We are actively working on addressing these limitations and improving overall efficiency. \ No newline at end of file diff --git a/examples/sid_gr/training/pretrain_sid_gr.py b/examples/sid_gr/training/pretrain_sid_gr.py new file mode 100644 index 000000000..984fb2642 --- /dev/null +++ b/examples/sid_gr/training/pretrain_sid_gr.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from typing import Tuple + +import commons.utils.initialize as init +import gin +import torch +from commons.distributed.sharding import make_optimizer_and_shard +from commons.optimizer import OptimizerParam +from commons.pipeline.train_pipeline import ( + JaggedMegatronPrefetchTrainPipelineSparseDist, + JaggedMegatronTrainNonePipeline, + JaggedMegatronTrainPipelineSparseDist, +) +from commons.utils.logger import print_rank_0 +from configs.args_to_config import create_embedding_config +from configs.gpt_config import get_gpt_config +from configs.sid_gin_config_args import ( + DatasetArgs, + EmbeddingArgs, + NetworkArgs, + OptimizerArgs, + TensorModelParallelArgs, + TrainerArgs, +) +from datasets.sid_data_loader import get_train_and_test_data_loader +from model import get_sid_gr_model +from trainer.training import maybe_load_ckpts, train_with_pipeline + + +def get_dataset_and_embedding_args() -> Tuple[DatasetArgs, EmbeddingArgs]: + dataset_args = DatasetArgs() # type: ignore[call-arg] + + codebook_sizes = dataset_args.codebook_sizes + aggragated_codebook_size = sum(codebook_sizes) + # embedding feature names should match the dataset batch feature names + embedding_args = EmbeddingArgs( # sid tuples share one embedding table + feature_names=[ + dataset_args._history_sid_feature_name, + dataset_args._candidate_sid_feature_name, + ], # sid tuples share one embedding table + table_name="codebook", + item_vocab_size_or_capacity=aggragated_codebook_size, + sharding_type="data_parallel", + ) + + return dataset_args, embedding_args + + +def create_optimizer_params(optimizer_args: OptimizerArgs): + return OptimizerParam( + optimizer_str=optimizer_args.optimizer_str, + learning_rate=optimizer_args.learning_rate, + adam_beta1=optimizer_args.adam_beta1, + adam_beta2=optimizer_args.adam_beta2, + adam_eps=optimizer_args.adam_eps, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="SID-GR Example Arguments", allow_abbrev=False + ) + parser.add_argument("--gin-config-file", type=str) + args = parser.parse_args() + gin.parse_config_file(args.gin_config_file) + trainer_args = TrainerArgs() + ( + dataset_args, + embedding_args, + ) = get_dataset_and_embedding_args() # auto-set by gin-config + network_args = NetworkArgs() + # this is a kinda hard code. + # when share_lm_head_across_hierarchies is True, we must deduplicate the label across hierarchy. + dataset_args.deduplicate_label_across_hierarchy = ( + network_args.share_lm_head_across_hierarchies + ) + + optimizer_args = OptimizerArgs() + tp_args = TensorModelParallelArgs() + + init.initialize_distributed() + init.initialize_model_parallel( + tensor_model_parallel_size=tp_args.tensor_model_parallel_size + ) + init.set_random_seed(trainer_args.seed) + free_memory, total_memory = torch.cuda.mem_get_info() + print_rank_0( + f"distributed env initialization done. Free cuda memory: {free_memory / (1024 ** 2):.2f} MB" + ) + gpt_config = get_gpt_config( + network_args.hidden_size, + network_args.kv_channels, + network_args.num_attention_heads, + network_args.num_layers, + torch.bfloat16, + hidden_dropout=network_args.hidden_dropout, + tensor_model_parallel_size=tp_args.tensor_model_parallel_size, + loss_on_history=dataset_args.max_candidate_length == 0, + ) + embedding_config = create_embedding_config(network_args.hidden_size, embedding_args) + model = get_sid_gr_model( + decoder_config=gpt_config, + codebook_embedding_config=embedding_config, + codebook_sizes=dataset_args.codebook_sizes, + num_hierarchies=dataset_args.num_hierarchies, + normalization="RMSNorm", + top_k_for_generation=trainer_args.top_k_for_generation, + eval_metrics=trainer_args.eval_metrics, + share_lm_head_across_hierarchies=network_args.share_lm_head_across_hierarchies, + ) + + optimizer_param = create_optimizer_params(optimizer_args) + model_train, dense_optimizer = make_optimizer_and_shard( + model, + config=gpt_config, + sparse_optimizer_param=optimizer_param, + dense_optimizer_param=optimizer_param, + dynamicemb_options_dict={}, + pipeline_type=trainer_args.pipeline_type, + ) + stateful_metric_module = None + train_dataloader, test_dataloader = get_train_and_test_data_loader( + dataset_args, trainer_args + ) + free_memory, total_memory = torch.cuda.mem_get_info() + print_rank_0( + f"model initialization done, start training. Free cuda memory: {free_memory / (1024 ** 2):.2f} MB" + ) + + maybe_load_ckpts(trainer_args.ckpt_load_dir, model, dense_optimizer) + if trainer_args.pipeline_type in ["prefetch", "native"]: + pipeline_factory = ( + JaggedMegatronPrefetchTrainPipelineSparseDist + if trainer_args.pipeline_type == "prefetch" + else JaggedMegatronTrainPipelineSparseDist + ) + pipeline = pipeline_factory( + model_train, + dense_optimizer, + device=torch.device("cuda", torch.cuda.current_device()), + ) + else: + pipeline = JaggedMegatronTrainNonePipeline( + model_train, + dense_optimizer, + device=torch.device("cuda", torch.cuda.current_device()), + ) + train_with_pipeline( + pipeline, + stateful_metric_module, + trainer_args, + train_dataloader, + test_dataloader, + dense_optimizer, + ) + init.destroy_global_state() + + +if __name__ == "__main__": + main() diff --git a/examples/sid_gr/training/trainer/__init__.py b/examples/sid_gr/training/trainer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/sid_gr/training/trainer/training.py b/examples/sid_gr/training/trainer/training.py new file mode 100644 index 000000000..284a10570 --- /dev/null +++ b/examples/sid_gr/training/trainer/training.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from datetime import datetime +from itertools import chain, count, cycle, islice +from typing import Iterator, Optional, Union + +import commons.checkpoint as checkpoint +import torch # pylint: disable-unused-import +import torch.distributed as dist +from commons.checkpoint import get_unwrapped_module +from commons.pipeline.train_pipeline import ( + JaggedMegatronPrefetchTrainPipelineSparseDist, + JaggedMegatronTrainNonePipeline, + JaggedMegatronTrainPipelineSparseDist, +) +from commons.utils.gpu_timer import GPUTimer +from commons.utils.logger import print_rank_0 +from commons.utils.stringify import stringify_dict +from configs.sid_gin_config_args import TrainerArgs +from model.gpt_model import SIDGRModel + +try: + from rich.progress import track +except ImportError: + track = lambda x, description: x + + +def evaluate( + pipeline: Union[ + JaggedMegatronPrefetchTrainPipelineSparseDist, + JaggedMegatronTrainNonePipeline, + JaggedMegatronTrainPipelineSparseDist, + ], + stateful_metric_module: torch.nn.Module, + trainer_args: TrainerArgs, + eval_loader: torch.utils.data.DataLoader, +): + iterated_eval_loader = islice(eval_loader, len(eval_loader)) + model = get_unwrapped_module(pipeline._model) + max_eval_iters = trainer_args.max_eval_iters or len(eval_loader) + max_eval_iters = min(max_eval_iters, len(eval_loader)) + for i in track( + range(max_eval_iters), total=max_eval_iters, description="Evaluating" + ): + # for batch in iterated_eval_loader: + batch = next(iterated_eval_loader) + batch = batch.to(torch.cuda.current_device()) + labels = batch.labels + generated_sids, log_probs = model.generate(batch) + model.evaluator(log_probs, generated_sids, labels) + compute_res = model.evaluator.compute() + # reset the evaluator for the next evaluation + model.evaluator.reset() + print_rank_0( + f"[evaluation iters:{max_eval_iters}, batch size:{trainer_args.eval_batch_size}], result:\n " + + stringify_dict(compute_res, prefix="Metrics", sep="\n ") + ) + + +def maybe_load_ckpts( + ckpt_load_dir: str, + model: SIDGRModel, + dense_optimizer: Optional[torch.optim.Optimizer] = None, +): + if ckpt_load_dir == "": + return + + assert os.path.exists( + ckpt_load_dir + ), f"ckpt_load_dir {ckpt_load_dir} does not exist" + + print_rank_0(f"Loading checkpoints from {ckpt_load_dir}") + checkpoint.load(ckpt_load_dir, model, dense_optimizer=dense_optimizer) + print_rank_0(f"Checkpoints loaded!!") + + +def save_ckpts( + ckpt_save_dir: str, + model: SIDGRModel, + dense_optimizer: Optional[torch.optim.Optimizer] = None, +): + print_rank_0(f"Saving checkpoints to {ckpt_save_dir}") + import shutil + + if dist.get_rank() == 0: + if os.path.exists(ckpt_save_dir): + shutil.rmtree(ckpt_save_dir) + try: + os.makedirs(ckpt_save_dir, exist_ok=True) + except Exception as e: + raise Exception("can't build path:", ckpt_save_dir) from e + dist.barrier(device_ids=[torch.cuda.current_device()]) + checkpoint.save(ckpt_save_dir, model, dense_optimizer=dense_optimizer) + print_rank_0(f"Checkpoints saved!!") + + +# TODO. Use itertools.batched if python version is 3.12+ +def batched(it: Iterator, n: int): + assert n >= 1 + for x in it: + yield chain((x,), islice(it, n - 1)) + + +def train_with_pipeline( + pipeline: Union[ + JaggedMegatronPrefetchTrainPipelineSparseDist, + JaggedMegatronTrainNonePipeline, + JaggedMegatronTrainPipelineSparseDist, + ], + stateful_metric_module: torch.nn.Module, + trainer_args: TrainerArgs, + train_loader: torch.utils.data.DataLoader, + eval_loader: torch.utils.data.DataLoader, + dense_optimizer: torch.optim.Optimizer, +): + gpu_timer = GPUTimer() + max_train_iters = trainer_args.max_train_iters or len(train_loader) + gpu_timer.start() + last_td = 0 + + # using a tensor on gpu to avoid d2h copy + tokens_logged = torch.zeros(1).cuda().float() + # limit the number of iters to max_train_iters + # we support max_train_iters > n_batches, i.e. multiple epochs + train_loader_iter = islice(cycle(iter(train_loader)), max_train_iters) + + # every eval iter + n = trainer_args.eval_interval if trainer_args.eval_interval else max_train_iters + # data loader is split into num_iters / eval_interval (iters) slices where each slice contains n batches + iter_slices = batched(train_loader_iter, n) + start_iter = 0 + pipeline._model.train() + # note that torch profiler is exclusive with cuda profiler on GPU side. + torch_profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + # record_shapes=True, + with_stack=True, + with_flops=True, + ) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + for batched_iterator in iter_slices: + # for one slice(every eval interval) + for train_iter in count(start_iter): + if trainer_args.profile and train_iter == trainer_args.profile_step_start: + dist.barrier(device_ids=[torch.cuda.current_device()]) + torch.cuda.profiler.start() + torch_profiler.start() + if trainer_args.profile and train_iter == trainer_args.profile_step_end: + torch.cuda.profiler.stop() + dist.barrier(device_ids=[torch.cuda.current_device()]) + torch_profiler.stop() + trace_name = f"sid_gr_trace_{timestamp}.json" + trace_file = os.path.join(trainer_args.log_dir, trace_name) + torch_profiler.export_chrome_trace(trace_file) + if ( + train_iter * trainer_args.ckpt_save_interval > 0 + and train_iter % trainer_args.ckpt_save_interval == 0 + ): + save_path = os.path.join( + trainer_args.ckpt_save_dir, f"iter{train_iter}" + ) + save_ckpts(save_path, pipeline._model, dense_optimizer) + try: + torch.cuda.nvtx.range_push(f"step {train_iter}") + reporting_loss, logits = pipeline.progress( + batched_iterator + ) # Exception raised here + tokens_logged += reporting_loss[1] + if ( + train_iter > 0 and (train_iter + 1) % trainer_args.log_interval == 0 + ) or trainer_args.log_interval == 1: + gpu_timer.stop() + cur_td = gpu_timer.elapsed_time() - last_td + print_rank_0( + f"[train] [iter {train_iter}, tokens {int(tokens_logged.item())}, elapsed_time {cur_td:.2f} ms]: loss {reporting_loss[0] / reporting_loss[1]:.6f}" + ) + last_td = cur_td + last_td + tokens_logged.zero_() + # evaluate the model + if ( + train_iter > 0 and train_iter % trainer_args.eval_interval == 0 + ) or trainer_args.eval_interval == 1: + pipeline._model.eval() + evaluate( + pipeline, + stateful_metric_module, + trainer_args=trainer_args, + eval_loader=eval_loader, + ) + pipeline._model.train() + + except StopIteration: + start_iter = train_iter + break + finally: + # log + torch.cuda.nvtx.range_pop() diff --git a/pyproject.toml b/pyproject.toml index 879275432..2ac3fc478 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,5 +4,5 @@ exclude = [ "examples/hstu/ops/triton_ops/*", "examples/hstu/ops/fused_hstu_op.py", "corelib/*", - "examples/hstu/pipeline/*" + "examples/commons/pipeline/*", ] \ No newline at end of file