Commit 6a5be94

add preprocessing mlp for hstu (#98)
* add item mlp and contextual mlp
* fix
* fix ci
* fix
1 parent 2525ee4 commit 6a5be94

20 files changed: +517 / -397 lines

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -29,4 +29,5 @@ embedding_checkpoint/
 .idea/
 .DS_Store
 *.pickle
-*.xlsx
+*.xlsx
+pcie_lookup_poc/*

examples/commons/utils/initialize.py

Lines changed: 7 additions & 1 deletion
@@ -16,7 +16,13 @@
 import os

 import torch
-from megatron.core import parallel_state, tensor_parallel
+
+try:
+    from megatron.core import parallel_state, tensor_parallel
+except ImportError:
+    print("megatron.core is not installed, training is not supported.")
+    parallel_state = None
+    tensor_parallel = None


 def initialize_single_rank():
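
With the guarded import, this module can be loaded in inference-only environments where Megatron-LM is absent. A minimal usage sketch of that fallback (the import path follows examples/commons/utils/initialize.py; the check itself is illustrative and not part of this commit):

# Hypothetical sketch: detect the fallback installed by the try/except above.
from commons.utils import initialize  # path assumed; examples/ is expected to be on sys.path

if initialize.parallel_state is None:
    # megatron.core was not importable; skip training-only setup.
    print("Megatron-Core unavailable: training paths are disabled, inference still works.")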

examples/hstu/benchmark_ranking.gin

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,8 @@ BenchmarkDatasetArgs.contextual_feature_names=[]
 BenchmarkDatasetArgs.action_feature_name='action'
 BenchmarkDatasetArgs.max_num_candidates=0

+NetworkArgs.item_embedding_dim = 128
+NetworkArgs.contextual_embedding_dim = 256
 NetworkArgs.num_layers = 8
 NetworkArgs.num_attention_heads = 4
 NetworkArgs.hidden_size = 1024

examples/hstu/configs/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 from .hstu_config import (
     HSTUConfig,
     HSTULayerType,
+    HSTUPreprocessingConfig,
     KernelBackend,
     PositionEncodingConfig,
     get_hstu_config,
@@ -29,6 +30,7 @@
     "task_config",
     "ConfigType",
     "PositionEncodingConfig",
+    "HSTUPreprocessingConfig",
     "HSTUConfig",
     "get_hstu_config",
     "RankingConfig",

examples/hstu/configs/hstu_config.py

Lines changed: 11 additions & 0 deletions
@@ -77,12 +77,19 @@ class PositionEncodingConfig:
     use_time_encoding: bool


+@dataclass
+class HSTUPreprocessingConfig:
+    item_embedding_dim: int
+    contextual_embedding_dim: int
+
+
 @dataclass
 class HSTUConfig(TransformerConfig):
     """
     HSTUConfig is a configuration data class for the HSTU model, inheriting from TransformerConfig.

     Args:
+        hstu_preprocessing_config (HSTUPreprocessingConfig): HSTU preprocessing config. Defaults to None.
         position_encoding_config (PositionEncodingConfig): Position embedding config. Defaults to None.
         is_causal (bool): Indicates if the model is causal. Defaults to True.
         enable_relative_attention_bias (bool): Flag to enable relative attention bias. Defaults to False.
@@ -97,6 +104,7 @@ class HSTUConfig(TransformerConfig):
         recompute_input_silu (bool): Flag to enable recompute input silu. Defaults to False.
     """

+    hstu_preprocessing_config: Optional[HSTUPreprocessingConfig] = None
     position_encoding_config: Optional[PositionEncodingConfig] = None
     is_causal: bool = True
     enable_relative_attention_bias: bool = False
@@ -131,6 +139,7 @@ def get_hstu_config(
     num_attention_heads,
     num_layers,
     dtype,
+    hstu_preprocessing_config: Optional[HSTUPreprocessingConfig] = None,
     position_encoding_config: Optional[PositionEncodingConfig] = None,
     hidden_dropout=0.2,
     norm_epsilon=1e-5,
@@ -156,6 +165,7 @@ def get_hstu_config(
         num_attention_heads (int): Number of attention heads.
         num_layers (int): Number of attention layers.
         dtype (torch.dtype): Data type (e.g., torch.float16).
+        hstu_preprocessing_config (Optional[HSTUPreprocessingConfig], optional): HSTU preprocessing config. Defaults to None.
         position_encoding_config (Optional[PositionEncodingConfig], optional): Position embedding config. Defaults to None.
         hidden_dropout (float, optional): Dropout rate for hidden layers. Defaults to 0.2.
         norm_epsilon (float, optional): Epsilon value for normalization. Defaults to 1e-5.
@@ -181,6 +191,7 @@ def get_hstu_config(
         async_wgrad_stream = None
         async_wgrad_event = None
     return HSTUConfig(  # type: ignore
+        hstu_preprocessing_config=hstu_preprocessing_config,
         position_encoding_config=position_encoding_config,
         hidden_size=hidden_size,
         kv_channels=kv_channels,

examples/hstu/configs/inference_config.py

Lines changed: 2 additions & 1 deletion
@@ -17,7 +17,7 @@

 import torch

-from .hstu_config import PositionEncodingConfig
+from .hstu_config import HSTUPreprocessingConfig, PositionEncodingConfig


 @dataclass
@@ -156,6 +156,7 @@ class InferenceHSTUConfig:
     is_causal: bool = True
     target_group_size: int = 1
     position_encoding_config: Optional[PositionEncodingConfig] = None
+    hstu_preprocessing_config: Optional[HSTUPreprocessingConfig] = None

     def __post_init__(self):
         assert self.is_causal

examples/hstu/dataset/utils.py

Lines changed: 26 additions & 0 deletions
@@ -77,6 +77,32 @@ def __post_init__(self):
         )
         assert isinstance(self.max_num_candidates, int)

+    def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":  # type: ignore
+        """
+        Move the batch to the specified device.
+
+        Args:
+            device (torch.device): The device to move the batch to.
+            non_blocking (bool, optional): Whether to perform the move asynchronously. Defaults to False.
+
+        Returns:
+            RankingBatch: The batch on the specified device.
+        """
+        return Batch(
+            features=self.features.to(device=device, non_blocking=non_blocking),
+            batch_size=self.batch_size,
+            feature_to_max_seqlen=self.feature_to_max_seqlen,
+            contextual_feature_names=self.contextual_feature_names,
+            item_feature_name=self.item_feature_name,
+            action_feature_name=self.action_feature_name,
+            max_num_candidates=self.max_num_candidates,
+            num_candidates=self.num_candidates.to(
+                device=device, non_blocking=non_blocking
+            )
+            if self.num_candidates is not None
+            else None,
+        )
+
     @staticmethod
     def random(
         batch_size: int,
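
A short, hedged usage sketch for the new Batch.to helper (the module path and the surrounding wiring are assumptions; only the .to(...) call is what this hunk adds):

import torch

from dataset.utils import Batch  # path assumed from examples/hstu/dataset/utils.py


def move_batch_to_gpu(batch: Batch) -> Batch:
    """Illustrative helper: copy a host-side batch onto the current CUDA device."""
    device = torch.device("cuda", torch.cuda.current_device())
    # non_blocking=True lets the H2D copy overlap with compute when the source
    # tensors sit in pinned memory; num_candidates (if present) follows along.
    return batch.to(device, non_blocking=True)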

examples/hstu/model/inference_ranking_gr.py

Lines changed: 2 additions & 2 deletions
@@ -358,7 +358,7 @@ def forward(
     ):
         with torch.inference_mode():
             kvcache_metadata = self.prepare_kv_cache(batch, user_ids, user_start_pos)
-            jagged_data = self._hstu_block.hstu_preprocess(
+            jagged_data = self._hstu_block._preprocessor(
                 embeddings=self._embedding_collection(batch.features),
                 batch=batch,
             )
@@ -400,7 +400,7 @@ def forward(
                 torch.cuda.current_stream()
             )

-            jagged_data = self._hstu_block.hstu_postprocess(jagged_data)
+            jagged_data = self._hstu_block._postprocessor(jagged_data)
             jagged_item_logit = self._dense_module(jagged_data.values)
             self._offload_states = self.offload_kv_cache_async(
                 user_ids, kvcache_metadata

examples/hstu/model/ranking_gr.py

Lines changed: 1 addition & 7 deletions
@@ -50,17 +50,11 @@ def __init__(
         self._hstu_config = hstu_config
         self._task_config = task_config

-        self._embedding_dim = hstu_config.hidden_size
-        for ebc_config in task_config.embedding_configs:
-            assert (
-                ebc_config.dim == self._embedding_dim
-            ), "hstu layer hidden size should equal to embedding dim"
-
         self._embedding_collection = ShardedEmbedding(task_config.embedding_configs)

         self._hstu_block = HSTUBlock(hstu_config)
         self._mlp = MLP(
-            self._embedding_dim,
+            hstu_config.hidden_size,
             task_config.prediction_head_arch,
             task_config.prediction_head_act_type,
             task_config.prediction_head_bias,
Lines changed: 7 additions & 75 deletions
@@ -1,6 +1,6 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

-from typing import Dict, Optional, Union
+from typing import Dict, Union

 import torch
 from commons.utils.nvtx_op import output_nvtx_hook
@@ -9,14 +9,9 @@
 from megatron.core.transformer.module import MegatronModule
 from modules.debug.debug_hstu_layer import HSTULayer as DebugHSTULayer
 from modules.fused_hstu_layer import FusedHSTULayer
+from modules.hstu_processor import HSTUBlockPostprocessor, HSTUBlockPreprocessor
 from modules.jagged_data import JaggedData
 from modules.native_hstu_layer import HSTULayer as NativeHSTULayer
-from modules.position_encoder import HSTUPositionalEncoder
-from modules.utils import hstu_postprocess_embeddings, hstu_preprocess_embeddings
-from ops.triton_ops.triton_jagged import (  # type: ignore[attr-defined]
-    triton_concat_2D_jagged,
-    triton_split_2D_jagged,
-)
 from torchrec.sparse.jagged_tensor import JaggedTensor


@@ -39,16 +34,9 @@ def __init__(
         if self.config.fp16:
             self._training_dtype = torch.float16

-        self._positional_encoder: Optional[HSTUPositionalEncoder] = None
-        if config.position_encoding_config is not None:
-            self._positional_encoder = HSTUPositionalEncoder(
-                num_position_buckets=config.position_encoding_config.num_position_buckets,
-                num_time_buckets=config.position_encoding_config.num_time_buckets,
-                embedding_dim=config.hidden_size,
-                is_inference=False,
-                use_time_encoding=config.position_encoding_config.use_time_encoding,
-                training_dtype=self._training_dtype,
-            )
+        self._preprocessor = HSTUBlockPreprocessor(config, is_inference=False)
+        self._postprocessor = HSTUBlockPostprocessor(is_inference=False)
+
         HSTULayerImpl = (
             FusedHSTULayer
             if config.hstu_layer_type == HSTULayerType.FUSED
@@ -59,62 +47,6 @@ def __init__(
         self._attention_layers = torch.nn.ModuleList(
             [HSTULayerImpl(config) for l in range(self.config.num_layers)]
         )
-        self._dropout_ratio = config.hidden_dropout
-
-    @output_nvtx_hook(nvtx_tag="HSTUBlock preprocess", hook_key_or_attr_name="values")
-    def hstu_preprocess(
-        self, embeddings: Dict[str, JaggedTensor], batch: RankingBatch
-    ) -> JaggedData:
-        """
-        Preprocesses the embeddings for use in the HSTU architecture.
-
-        This method performs the following steps:
-        1. **Interleaving**: If action embeddings are present, interleaves them with item embeddings.
-        2. **Concatenation**: Concatenates contextual, item, and action embeddings for each sample, following the order specified in the batch.
-        3. **Position Encoding**: Applies position encoding to the concatenated embeddings.
-
-        Args:
-            embeddings (Dict[str, JaggedTensor]): A dictionary of embeddings where each key corresponds to a feature name and the value is a jagged tensor.
-            batch (RankingBatch): The batch of ranking data.
-
-        Returns:
-            JaggedData: The preprocessed jagged data, ready for further processing in the HSTU architecture.
-        """
-        # Interleaving & concatenation
-        jd = hstu_preprocess_embeddings(embeddings, batch, is_inference=False)
-
-        if self._positional_encoder is not None:
-            jd.values = self._positional_encoder(
-                max_seq_len=jd.max_seqlen,
-                seq_lengths=jd.seqlen,
-                seq_offsets=jd.seqlen_offsets,
-                seq_timestamps=None,
-                seq_embeddings=jd.values,
-                num_targets=jd.num_candidates,
-            )
-
-        jd.values = torch.nn.functional.dropout(
-            jd.values,
-            p=self._dropout_ratio,
-            training=self.training,
-        ).to(self._training_dtype)
-        return jd
-
-    @output_nvtx_hook(nvtx_tag="HSTUBlock postprocess", hook_key_or_attr_name="values")
-    def hstu_postprocess(self, jd: JaggedData) -> JaggedData:
-        """
-        Postprocess the output from the HSTU architecture.
-        1. If max_num_candidates > 0, split and only keep last ``num_candidates`` embeddings as candidates embedding for further processing.
-        2. Remove action embeddings if present. Only use item embedding for further processing.
-
-        Args:
-            jd (JaggedData): The jagged data output from the HSTU architecture that needs further processing.
-
-        Returns:
-            JaggedData: The postprocessed jagged data.
-        """
-
-        return hstu_postprocess_embeddings(jd, is_inference=False)

     @output_nvtx_hook(nvtx_tag="HSTUBlock", hook_key_or_attr_name="values")
     def forward(
@@ -132,7 +64,7 @@ def forward(
         Returns:
             JaggedData: The output jagged data.
         """
-        jd = self.hstu_preprocess(embeddings, batch)
+        jd = self._preprocessor(embeddings, batch)
         for hstu_layer in self._attention_layers:
             jd = hstu_layer(jd)
-        return self.hstu_postprocess(jd)
+        return self._postprocessor(jd)
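
The new HSTUBlockPreprocessor and HSTUBlockPostprocessor come from modules/hstu_processor.py, which is not included in this excerpt. As a rough illustration of what the commit title describes, the sketch below shows per-feature MLPs projecting item and contextual embeddings up to the HSTU hidden size, which is also why ranking_gr.py could drop the assertion that embedding dim equals hidden size. Every name below and the SiLU activation are assumptions, not the repository's actual implementation:

import torch


class ToyPreprocessingMLP(torch.nn.Module):
    """Illustrative only: project item/contextual embeddings to hidden_size."""

    def __init__(self, item_embedding_dim: int, contextual_embedding_dim: int, hidden_size: int):
        super().__init__()
        self.item_proj = torch.nn.Sequential(
            torch.nn.Linear(item_embedding_dim, hidden_size),
            torch.nn.SiLU(),
        )
        self.contextual_proj = torch.nn.Sequential(
            torch.nn.Linear(contextual_embedding_dim, hidden_size),
            torch.nn.SiLU(),
        )

    def forward(self, item_emb: torch.Tensor, contextual_emb: torch.Tensor):
        # Both outputs share hidden_size afterwards, so the embedding tables no longer
        # need to match the HSTU hidden size.
        return self.item_proj(item_emb), self.contextual_proj(contextual_emb)


# Quick smoke test with the benchmark_ranking.gin values (128 / 256 -> 1024).
mlp = ToyPreprocessingMLP(item_embedding_dim=128, contextual_embedding_dim=256, hidden_size=1024)
items, ctx = mlp(torch.randn(10, 128), torch.randn(4, 256))
assert items.shape == (10, 1024) and ctx.shape == (4, 1024)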
