# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

-from typing import Dict, Optional, Union
+from typing import Dict, Union

import torch
from commons.utils.nvtx_op import output_nvtx_hook
from megatron.core.transformer.module import MegatronModule
from modules.debug.debug_hstu_layer import HSTULayer as DebugHSTULayer
from modules.fused_hstu_layer import FusedHSTULayer
+from modules.hstu_processor import HSTUBlockPostprocessor, HSTUBlockPreprocessor
from modules.jagged_data import JaggedData
from modules.native_hstu_layer import HSTULayer as NativeHSTULayer
-from modules.position_encoder import HSTUPositionalEncoder
-from modules.utils import hstu_postprocess_embeddings, hstu_preprocess_embeddings
-from ops.triton_ops.triton_jagged import (  # type: ignore[attr-defined]
-    triton_concat_2D_jagged,
-    triton_split_2D_jagged,
-)
from torchrec.sparse.jagged_tensor import JaggedTensor

@@ -39,16 +34,9 @@ def __init__(
        if self.config.fp16:
            self._training_dtype = torch.float16

-        self._positional_encoder: Optional[HSTUPositionalEncoder] = None
-        if config.position_encoding_config is not None:
-            self._positional_encoder = HSTUPositionalEncoder(
-                num_position_buckets=config.position_encoding_config.num_position_buckets,
-                num_time_buckets=config.position_encoding_config.num_time_buckets,
-                embedding_dim=config.hidden_size,
-                is_inference=False,
-                use_time_encoding=config.position_encoding_config.use_time_encoding,
-                training_dtype=self._training_dtype,
-            )
+        self._preprocessor = HSTUBlockPreprocessor(config, is_inference=False)
+        self._postprocessor = HSTUBlockPostprocessor(is_inference=False)
+
        HSTULayerImpl = (
            FusedHSTULayer
            if config.hstu_layer_type == HSTULayerType.FUSED
@@ -59,62 +47,6 @@ def __init__(
        self._attention_layers = torch.nn.ModuleList(
            [HSTULayerImpl(config) for l in range(self.config.num_layers)]
        )
-        self._dropout_ratio = config.hidden_dropout
-
-    @output_nvtx_hook(nvtx_tag="HSTUBlock preprocess", hook_key_or_attr_name="values")
-    def hstu_preprocess(
-        self, embeddings: Dict[str, JaggedTensor], batch: RankingBatch
-    ) -> JaggedData:
-        """
-        Preprocesses the embeddings for use in the HSTU architecture.
-
-        This method performs the following steps:
-        1. **Interleaving**: If action embeddings are present, interleaves them with item embeddings.
-        2. **Concatenation**: Concatenates contextual, item, and action embeddings for each sample, following the order specified in the batch.
-        3. **Position Encoding**: Applies position encoding to the concatenated embeddings.
-
-        Args:
-            embeddings (Dict[str, JaggedTensor]): A dictionary of embeddings where each key corresponds to a feature name and the value is a jagged tensor.
-            batch (RankingBatch): The batch of ranking data.
-
-        Returns:
-            JaggedData: The preprocessed jagged data, ready for further processing in the HSTU architecture.
-        """
-        # Interleaving & concatenation
-        jd = hstu_preprocess_embeddings(embeddings, batch, is_inference=False)
-
-        if self._positional_encoder is not None:
-            jd.values = self._positional_encoder(
-                max_seq_len=jd.max_seqlen,
-                seq_lengths=jd.seqlen,
-                seq_offsets=jd.seqlen_offsets,
-                seq_timestamps=None,
-                seq_embeddings=jd.values,
-                num_targets=jd.num_candidates,
-            )
-
-        jd.values = torch.nn.functional.dropout(
-            jd.values,
-            p=self._dropout_ratio,
-            training=self.training,
-        ).to(self._training_dtype)
-        return jd
-
-    @output_nvtx_hook(nvtx_tag="HSTUBlock postprocess", hook_key_or_attr_name="values")
-    def hstu_postprocess(self, jd: JaggedData) -> JaggedData:
-        """
-        Postprocess the output from the HSTU architecture.
-
-        1. If max_num_candidates > 0, split the sequence and keep only the last ``num_candidates`` embeddings as candidate embeddings for further processing.
-        2. Remove action embeddings if present; only item embeddings are used for further processing.
-
-        Args:
-            jd (JaggedData): The jagged data output from the HSTU architecture that needs further processing.
-
-        Returns:
-            JaggedData: The postprocessed jagged data.
-        """
-        return hstu_postprocess_embeddings(jd, is_inference=False)

    @output_nvtx_hook(nvtx_tag="HSTUBlock", hook_key_or_attr_name="values")
    def forward(
@@ -132,7 +64,7 @@ def forward(
        Returns:
            JaggedData: The output jagged data.
        """
-        jd = self.hstu_preprocess(embeddings, batch)
+        jd = self._preprocessor(embeddings, batch)
        for hstu_layer in self._attention_layers:
            jd = hstu_layer(jd)
-        return self.hstu_postprocess(jd)
+        return self._postprocessor(jd)
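For readers tracking where the deleted logic went: the sketch below reconstructs, from the removed lines above, what `HSTUBlockPreprocessor` in `modules.hstu_processor` plausibly does. The class name and constructor signature come from the diff; the body mirrors the removed `hstu_preprocess` method, and the bf16 default dtype is an assumption (the diff only shows the fp16 branch), so treat this as an illustrative sketch rather than the actual module.

```python
# Illustrative reconstruction only -- the real module lives in
# modules.hstu_processor and may differ in detail.
import torch

from modules.jagged_data import JaggedData
from modules.position_encoder import HSTUPositionalEncoder
from modules.utils import hstu_preprocess_embeddings


class HSTUBlockPreprocessor(torch.nn.Module):
    def __init__(self, config, is_inference: bool = False):
        super().__init__()
        self._is_inference = is_inference
        self._dropout_ratio = config.hidden_dropout
        # Assumption: bf16 unless fp16 is requested; only the fp16 branch is visible above.
        self._training_dtype = torch.float16 if config.fp16 else torch.bfloat16
        self._positional_encoder = None
        if config.position_encoding_config is not None:
            self._positional_encoder = HSTUPositionalEncoder(
                num_position_buckets=config.position_encoding_config.num_position_buckets,
                num_time_buckets=config.position_encoding_config.num_time_buckets,
                embedding_dim=config.hidden_size,
                is_inference=is_inference,
                use_time_encoding=config.position_encoding_config.use_time_encoding,
                training_dtype=self._training_dtype,
            )

    def forward(self, embeddings, batch) -> JaggedData:
        # Interleave action embeddings with item embeddings, then concatenate
        # contextual, item, and action embeddings per sample.
        jd = hstu_preprocess_embeddings(embeddings, batch, is_inference=self._is_inference)
        if self._positional_encoder is not None:
            jd.values = self._positional_encoder(
                max_seq_len=jd.max_seqlen,
                seq_lengths=jd.seqlen,
                seq_offsets=jd.seqlen_offsets,
                seq_timestamps=None,
                seq_embeddings=jd.values,
                num_targets=jd.num_candidates,
            )
        # Dropout on the jagged values, then cast to the training dtype.
        jd.values = torch.nn.functional.dropout(
            jd.values, p=self._dropout_ratio, training=self.training
        ).to(self._training_dtype)
        return jd
```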
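Similarly, a minimal sketch of what `HSTUBlockPostprocessor` presumably wraps, assuming it simply forwards to the same `hstu_postprocess_embeddings` helper that the removed `hstu_postprocess` method called:

```python
# Illustrative reconstruction; the real class sits in modules.hstu_processor.
import torch

from modules.jagged_data import JaggedData
from modules.utils import hstu_postprocess_embeddings


class HSTUBlockPostprocessor(torch.nn.Module):
    def __init__(self, is_inference: bool = False):
        super().__init__()
        self._is_inference = is_inference

    def forward(self, jd: JaggedData) -> JaggedData:
        # Keep only the last num_candidates embeddings per sequence when
        # max_num_candidates > 0, and drop action embeddings if present.
        return hstu_postprocess_embeddings(jd, is_inference=self._is_inference)
```

Packaging both stages as callable `nn.Module`s lets the training path shown here and any inference path share one implementation, toggled by the `is_inference` flag instead of duplicated methods on `HSTUBlock`.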