Skip to content

Commit b24f307

Browse files
aliafzal authored and facebook-github-bot committed
Add RawIdTrackerWrapper within TBE to access tracked ids and raw ids
Summary: This diff introduces RawIdTrackerWrapper, a wrapper class containing the lookup and delete APIs registered during raw_ids_tracker initialization, used to access tracked ids and raw_ids. We needed to create a wrapper instead of passing in the tracker itself due to circular dependency issues, since TBE is wrapped under DMP (internal). This is needed to support MPZCH modules for raw embedding streaming. More details: https://docs.google.com/document/d/1KEHwiXKLgXwRIdDFBYopjX3OiP3mRLM24Qkbiiu-TgE/edit?tab=t.0#bookmark=id.lhhgee2cs6ld Differential Revision: D84925177
1 parent febf3f0 commit b24f307

File tree

1 file changed

+63
-2
lines changed

1 file changed

+63
-2
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from math import sqrt
1818
from typing import (
1919
Any,
20+
Callable,
2021
cast,
2122
Dict,
2223
Generic,
@@ -70,6 +71,7 @@
7071
GroupedEmbeddingConfig,
7172
ShardedEmbeddingTable,
7273
)
74+
from torchrec.distributed.model_tracker.types import IndexedLookup
7375
from torchrec.distributed.shards_wrapper import LocalShardsWrapper
7476
from torchrec.distributed.types import (
7577
Shard,
@@ -80,6 +82,7 @@
8082
TensorProperties,
8183
)
8284
from torchrec.distributed.utils import append_prefix, none_throws
85+
8386
from torchrec.modules.embedding_configs import (
8487
CountBasedEvictionPolicy,
8588
CountTimestampMixedEvictionPolicy,
@@ -97,13 +100,30 @@
97100
)
98101
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
99102

103+
100104
logger: logging.Logger = logging.getLogger(__name__)
101105

102106
RES_ENABLED_TABLES_STR = "res_enabled_tables"
103107
RES_STORE_SHARDS_STR = "res_store_shards"
104108
ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming"
105109

106110

111+
class RawIdTrackerWrapper:
    """Thin holder for the two raw-id-tracker callbacks used by TBE.

    Rather than handing the tracker object itself to TBE (which would
    create a circular dependency, since TBE is wrapped under DMP), only
    the two callables TBE needs are captured here at tracker init time.
    """

    def __init__(
        self,
        get_indexed_lookups: Callable[[List[str], Optional[str]], List[torch.Tensor]],
        delete: Callable[[int], None],
    ) -> None:
        # Returns tracked raw-id tensors for the given feature keys and an
        # optional consumer identifier (the caller passes the TBE uuid).
        self.get_indexed_lookups = get_indexed_lookups
        # Deletion callback keyed by an int; presumably removes tracked
        # entries up to that index — confirm against the tracker's API.
        self.delete = delete
125+
126+
107127
def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]:
108128
# populate res_params, which is used for raw embedding streaming
109129
# here only populates the params available in fused_params and TBE configs
@@ -2526,6 +2546,7 @@ def __init__(
25262546
self._lengths_per_emb: List[int] = []
25272547
self.table_name_to_count: Dict[str, int] = {}
25282548
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
2549+
self._raw_id_tracker_wrapper: Optional[RawIdTrackerWrapper] = None
25292550

25302551
for idx, table_config in enumerate(self._config.embedding_tables):
25312552
self._local_rows.append(table_config.local_rows)
@@ -2579,7 +2600,26 @@ def init_parameters(self) -> None:
25792600
weight_init_max,
25802601
)
25812602

2582-
def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
2603+
def forward(
2604+
self,
2605+
features: KeyedJaggedTensor,
2606+
) -> torch.Tensor:
2607+
forward_args: Dict[str, Any] = {}
2608+
if self._raw_id_tracker_wrapper is not None:
2609+
if isinstance(self.emb_module, SplitTableBatchedEmbeddingBagsCodegen):
2610+
raw_id_tracker_wrapper = self._raw_id_tracker_wrapper
2611+
assert (
2612+
raw_id_tracker_wrapper is not None
2613+
), "self._raw_id_tracker_wrapper should not be None"
2614+
# TODO: Calling get_indexed_lookups(None) retrieves raw IDs for ALL tracked FQNs,
2615+
# including those this TBE doesn't own, and advances the shared consumer read index.
2616+
# While storage isn't deleted, advancing the index prevents re-reading, which blocks
2617+
# other TBEs from accessing their tracked raw IDs.
2618+
raw_ids_list = raw_id_tracker_wrapper.get_indexed_lookups(
2619+
features.keys(), self.emb_module.uuid
2620+
)
2621+
if raw_ids_list:
2622+
forward_args["hash_zch_identities"] = torch.cat(raw_ids_list)
25832623
weights = features.weights_or_none()
25842624
if weights is not None and not torch.is_floating_point(weights):
25852625
weights = None
@@ -2591,17 +2631,22 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
25912631
SSDTableBatchedEmbeddingBags,
25922632
),
25932633
):
2634+
forward_args["batch_size_per_feature_per_rank"] = (
2635+
features.stride_per_key_per_rank()
2636+
)
2637+
2638+
if len(forward_args) == 0:
25942639
return self.emb_module(
25952640
indices=features.values().long(),
25962641
offsets=features.offsets().long(),
25972642
per_sample_weights=weights,
2598-
batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
25992643
)
26002644
else:
26012645
return self.emb_module(
26022646
indices=features.values().long(),
26032647
offsets=features.offsets().long(),
26042648
per_sample_weights=weights,
2649+
**forward_args,
26052650
)
26062651

26072652
# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -2668,6 +2713,22 @@ def named_parameters_by_table(
26682713
for name, param in self._param_per_table.items():
26692714
yield name, param
26702715

2716+
def init_raw_id_tracker(
    self,
    get_indexed_lookups: Callable[[List[str], Optional[str]], List[torch.Tensor]],
    delete: Callable[[int], None],
) -> None:
    """Register raw-id tracker callbacks on this module.

    Wraps the two callables in a RawIdTrackerWrapper and stores it on
    ``self._raw_id_tracker_wrapper``. Only takes effect when the underlying
    emb module is a SplitTableBatchedEmbeddingBagsCodegen; for any other
    emb module type this is a no-op and the wrapper attribute is untouched.
    """
    if not isinstance(self._emb_module, SplitTableBatchedEmbeddingBagsCodegen):
        return
    self._raw_id_tracker_wrapper = RawIdTrackerWrapper(get_indexed_lookups, delete)
2731+
26712732

26722733
class KeyValueEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule):
26732734
def __init__(

0 commit comments

Comments
 (0)