Merged
36 commits
a8f55ba
Move pipeline to commons
JacoCheung Nov 27, 2025
0250e5b
Move jagged concat ops and embedding to common
JacoCheung Dec 2, 2025
64ce0e5
Move ShardedEmbeddingConfig to commons
JacoCheung Dec 4, 2025
4b0cfca
Add sid gr model definition
JacoCheung Dec 4, 2025
5541234
Move distributed to commons
JacoCheung Dec 5, 2025
bbc4429
Move triton_ops.common to commons
JacoCheung Dec 8, 2025
c86417d
Add runnable GPTSID GR model
JacoCheung Dec 8, 2025
97a4b94
Add training pipeline and random dataset trainable
JacoCheung Dec 8, 2025
265d5a2
Restore mypy check
JacoCheung Dec 9, 2025
a46a445
Add disk dataset/dataloader and its utest
JacoCheung Dec 9, 2025
acea303
Separate history_seqlen and candidate_seqlen
JacoCheung Dec 9, 2025
7c54cd1
Fix the dataset and emb args feat name mismatch
JacoCheung Dec 10, 2025
2bb9090
Sort samples by userid for debugging
JacoCheung Dec 10, 2025
4e4118e
Enable arbitrary mask with local attention impl
JacoCheung Dec 12, 2025
1f814ea
Add beam search functionality
JacoCheung Dec 16, 2025
a7a0316
Add beam search individual module and eval metric test
JacoCheung Dec 18, 2025
05a498b
Add beam history sids check and eval metrics to gptmodel
JacoCheung Dec 18, 2025
6a05e79
Support dynamic beam and handle case when topk > num_candidates
JacoCheung Dec 18, 2025
580c37f
Fix bos split bug and add more hist info in BeamSearch
JacoCheung Dec 19, 2025
070e938
Fix mask def for mcore and mask construction error, make model overfi…
JacoCheung Dec 25, 2025
260e986
Add RMSNorm
JacoCheung Dec 30, 2025
b9ffa87
Fix attention mask definition and enable loss on history
JacoCheung Jan 3, 2026
59ff4c7
Add incomplete batch dataset test
JacoCheung Jan 4, 2026
d4a93e4
Fix incomplete eval batch
JacoCheung Jan 4, 2026
240644d
Fix generation mask
JacoCheung Jan 5, 2026
7b80791
Fix config for debugging eval
JacoCheung Jan 6, 2026
8a9e3a3
Enable single shared lm_head or individual lm_head across hierarchies
JacoCheung Jan 7, 2026
1af0f04
Add license header
JacoCheung Jan 7, 2026
482b277
Add sid_gr README
JacoCheung Jan 8, 2026
caa9ed9
Fix all utests of sid gr and update sid_amazn config
JacoCheung Jan 8, 2026
7ceeba1
Adjust the img size of sid ReadMe
JacoCheung Jan 8, 2026
6d55440
Lessen the max_train_iters
JacoCheung Jan 8, 2026
eae52ab
Rename data -> datasets
JacoCheung Jan 9, 2026
01c39f7
Rename hstu/dataset -> hstu/datasets and fix commons import error in …
JacoCheung Jan 9, 2026
74b1a0c
Restore HKV commit
JacoCheung Jan 9, 2026
6007d51
Remove sid/ops and move training link to head of README
JacoCheung Jan 9, 2026
2 changes: 2 additions & 0 deletions README.md
@@ -5,6 +5,7 @@ NVIDIA RecSys Examples is a collection of optimized recommender models and compo

The project includes:
- Examples for large-scale HSTU ranking and retrieval models through [TorchRec](https://github.com/pytorch/torchrec) and [Megatron-Core](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core) integration
- Examples for semantic-id based retrieval model through [TorchRec](https://github.com/pytorch/torchrec) and [Megatron-Core](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core) integration
- HSTU (Hierarchical Sequential Transduction Unit) attention operator support
- Dynamic Embeddings with GPU acceleration

@@ -47,6 +48,7 @@ For more detailed release notes, please refer our [releases](https://github.com/
# Get Started
The examples we supported:
- [HSTU recommender examples](./examples/hstu/README.md)
- [SID based generative recommender examples](./examples/sid_gr/README.md)

# Contribution Guidelines
Please see our [contributing guidelines](./CONTRIBUTING.md) for details on how to contribute to this project.
2 changes: 1 addition & 1 deletion docker/Dockerfile
@@ -79,5 +79,5 @@ RUN cd /workspace/recsys-examples/corelib/hstu && \
cd hopper && \
HSTU_DISABLE_ARBITRARY=TRUE HSTU_DISABLE_SM8x=TRUE HSTU_DISABLE_LOCAL=TRUE HSTU_DISABLE_RAB=TRUE HSTU_DISABLE_DELTA_Q=FALSE HSTU_DISABLE_DRAB=TRUE pip install .

RUN cd /workspace/recsys-examples/examples/hstu && \
RUN cd /workspace/recsys-examples/examples/commons && \
TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0" python3 setup.py install
@@ -1,16 +1,15 @@
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union

import torch
import torch.distributed as dist
from dataset.utils import RankingBatch, RetrievalBatch
from megatron.core import parallel_state
from ops.collective_ops import (
from commons.ops.collective_ops import (
gather_along_first_dim,
gatherv_along_first_dim,
jagged_tensor_allgather,
keyed_jagged_tensor_allgather,
)
from ops.grad_scaling import grad_scaling
from commons.ops.grad_scaling import grad_scaling
from megatron.core import parallel_state
from torchrec.sparse.jagged_tensor import JaggedTensor


@@ -35,9 +34,7 @@ def jt_dict_grad_scaling_and_allgather(


# The features field is a KJT, the input to the embedding module.
def dmp_batch_to_tp(
batch: Union[RetrievalBatch, RankingBatch], exclude_features: bool = True
) -> Union[RetrievalBatch, RankingBatch]:
def dmp_batch_to_tp(batch: Any, exclude_features: bool = True) -> Any:
tp_pg = parallel_state.get_tensor_model_parallel_group()
tp_size = dist.get_world_size(group=tp_pg)
batch_cls = type(batch)
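Judging from the imports above, dmp_batch_to_tp replicates each data-parallel batch across the tensor-parallel group by all-gathering its tensors along the first dimension (the jagged and keyed-jagged variants handle ragged inputs). As a rough illustration only, not the commons.ops implementation, a dense first-dim all-gather in plain torch.distributed could look like the sketch below; the helper name is hypothetical.

import torch
import torch.distributed as dist


def allgather_along_first_dim(t: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Every rank contributes its local tensor and receives the concatenation
    # of all ranks' tensors along dim 0. Equal per-rank shapes are assumed;
    # a "gatherv"-style variant would be needed for unequal shapes.
    world_size = dist.get_world_size(group=group)
    chunks = [torch.empty_like(t) for _ in range(world_size)]
    dist.all_gather(chunks, t.contiguous(), group=group)
    return torch.cat(chunks, dim=0)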
@@ -19,10 +19,11 @@
import torch
import torch.distributed as dist
import torchrec
from configs.task_config import OptimizerParam

# import our own finalize model grads
from distributed.finalize_model_grads import finalize_model_grads
from commons.distributed.finalize_model_grads import finalize_model_grads
from commons.modules.embedding import DataParallelEmbeddingCollection
from commons.optimizer import OptimizerParam
from dynamicemb import DynamicEmbTableOptions
from dynamicemb.get_planner import get_planner
from dynamicemb.planner import (
@@ -40,7 +41,6 @@
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import Float16Module
from modules.embedding import DataParallelEmbeddingCollection
from torch import distributed as dist
from torch.distributed.optim import (
_apply_optimizer_in_backward as apply_optimizer_in_backward,
Empty file.
@@ -14,6 +14,7 @@
# limitations under the License.
import copy
import os
from dataclasses import dataclass

# pyre-strict
from typing import Any, Dict, Iterator, List, Optional, Tuple
@@ -23,7 +24,6 @@
import torch.fx
import torch.nn as nn
from commons.utils.nvtx_op import output_nvtx_hook, register_setter_and_getter_for_nvtx
from configs.task_config import ShardedEmbeddingConfig
from dynamicemb.planner import (
DynamicEmbeddingShardingPlanner as DynamicEmbeddingShardingPlanner,
)
@@ -52,6 +52,40 @@
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor


@dataclass
class ShardedEmbeddingConfig:
"""
Configuration for sharded embeddings with sharding type. Inherits from BaseShardedEmbeddingConfig.

Args:
config (EmbeddingConfig): The embedding configuration.
sharding_type (str): The type of sharding, ``'data_parallel'`` | ``'model_parallel'``.
"""

"""
Base configuration for sharded embeddings.

Args:
feature_names (List[str]): The name of the features in this embedding.
table_name (str): The name of the table.
vocab_size (int): The size of the vocabulary.
dim (int): The dimension size of the embeddings.
sharding_type (str): The type of sharding, ``'data_parallel'`` | ``'model_parallel'``.
"""

feature_names: List[str]
table_name: str
vocab_size: int
dim: int
sharding_type: str

def __post_init__(self):
assert self.sharding_type in [
"data_parallel",
"model_parallel",
], "sharding type should be data_parallel or model_parallel"


def create_data_parallel_sharding_infos_by_sharding(
module: EmbeddingCollectionInterface,
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
@@ -132,6 +166,7 @@ class DataParallelEmbeddingCollection(torch.nn.Module):
"""
Sharded implementation of `EmbeddingCollection`.
This is part of the public API to allow for manual data dist pipelining.
We re-implement the DP embedding so that it can be wrapped by Megatron DDP.
"""

def __init__(
@@ -354,14 +389,19 @@ def forward(self, kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
Returns:
`Dict[str, JaggedTensor <https://pytorch.org/torchrec/concepts.html#jaggedtensor>]`: The output embeddings.
"""
mp_embeddings_awaitables = self._model_parallel_embedding_collection(kjt)
assert not (
self._model_parallel_embedding_collection is None
and self._data_parallel_embedding_collection is None
), "either model_parallel_embedding_collection or data_parallel_embedding_collection must be not None"
embeddings: Dict[str, JaggedTensor] = {}
if self._model_parallel_embedding_collection is not None:
mp_embeddings_awaitables = self._model_parallel_embedding_collection(kjt)
embeddings = {**embeddings, **(mp_embeddings_awaitables.wait())}
if self._data_parallel_embedding_collection is not None:
with torch.cuda.stream(self._side_stream):
dp_embeddings = self._data_parallel_embedding_collection(kjt)
torch.cuda.current_stream().wait_stream(self._side_stream)
embeddings = {**mp_embeddings_awaitables.wait(), **dp_embeddings}
else:
embeddings = mp_embeddings_awaitables.wait()
embeddings = {**embeddings, **dp_embeddings}
return embeddings

def export_local_embedding(self, table_name: str) -> Tuple[np.ndarray, np.ndarray]:
@@ -381,7 +421,7 @@ def export_local_embedding(self, table_name: str) -> Tuple[np.ndarray, np.ndarra
Example:
>>> # assume we have 2 ranks
>>> import torch
>>> from modules.embedding import ShardedEmbedding
>>> from commons.modules.embedding import ShardedEmbedding
>>> from configs.task_config import ShardedEmbeddingConfig
            >>> from commons.utils import initialize as init
>>> from commons.utils.logger import print_rank_0
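The reworked forward above makes both embedding collections optional and overlaps the data-parallel lookup, launched on a side CUDA stream, with the model-parallel awaitable. A minimal sketch of that merge pattern, with hypothetical function and argument names and the two collections treated as opaque callables:

from typing import Dict

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor


def lookup_embeddings(
    mp_collection,            # model-parallel collection returning an awaitable, or None
    dp_collection,            # data-parallel collection returning a dict, or None
    kjt: KeyedJaggedTensor,
    side_stream: torch.cuda.Stream,
) -> Dict[str, JaggedTensor]:
    assert mp_collection is not None or dp_collection is not None
    # Launch the model-parallel lookup first so its communication can overlap
    # with the data-parallel work below.
    mp_awaitable = mp_collection(kjt) if mp_collection is not None else None

    dp_embeddings: Dict[str, JaggedTensor] = {}
    if dp_collection is not None:
        # Run the data-parallel lookup on a side stream, then make the main
        # stream wait for it before the results are read.
        with torch.cuda.stream(side_stream):
            dp_embeddings = dp_collection(kjt)
        torch.cuda.current_stream().wait_stream(side_stream)

    embeddings: Dict[str, JaggedTensor] = {}
    if mp_awaitable is not None:
        embeddings.update(mp_awaitable.wait())
    embeddings.update(dp_embeddings)
    return embeddings

A caller would create side_stream once (torch.cuda.Stream()) and reuse it across iterations, mirroring the _side_stream member used by the collection above.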
@@ -16,7 +16,7 @@

import torch
import torch.distributed as dist
from ops.length_to_offsets import length_to_complete_offsets
from commons.ops.length_to_offsets import length_to_complete_offsets
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor


@@ -2,7 +2,7 @@

import hstu_cuda_ops
import torch
from ops.length_to_offsets import length_to_complete_offsets
from commons.ops.length_to_offsets import length_to_complete_offsets


class _JaggedTensorOpFunction(torch.autograd.Function):
File renamed without changes.
@@ -33,7 +33,7 @@
import dataclasses
from dataclasses import dataclass
from enum import Enum, unique
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Tuple, Union

import torch

@@ -129,55 +129,6 @@ class HammerKernel(Enum): # type: ignore[no-redef]
TRITON_CC = "TRITON_CC"


class GRModuleBase(torch.nn.Module):
_is_inference: bool
_use_triton_cc: bool
_custom_kernel: bool
_hammer_kernel: Optional[HammerKernel] = None

def __init__(
self,
is_inference: bool,
use_triton_cc: bool = True,
custom_kernel: bool = True,
hammer_kernel: Optional[HammerKernel] = None,
) -> None:
super().__init__()
self._is_inference = is_inference
self._use_triton_cc = use_triton_cc
self._custom_kernel = custom_kernel
self._hammer_kernel = hammer_kernel

def hammer_kernel(self) -> HammerKernel:
kernel = self._hammer_kernel
if kernel is not None:
return kernel
if self._custom_kernel:
if self._is_inference and self._use_triton_cc:
return HammerKernel.TRITON_CC
else:
return HammerKernel.TRITON
else:
return HammerKernel.PYTORCH

# pyre-ignore[2]
def recursive_setattr(self, name: str, value: Any) -> None:
for _, module in self.named_modules():
if hasattr(module, name):
setattr(module, name, value)

@property
def predict_mode(self) -> bool:
return self._is_inference

@property
def eval_mode(self) -> bool:
return (not self._is_inference) and (not self.training)

@property
def train_mode(self) -> bool:
return (not self._is_inference) and self.training


def generate_sparse_seq_len(
size: int,