Skip to content

Commit 8ce55cf

Browse files
aliafzal authored and facebook-github-bot committed
Add RawIdTrackerWrapper within TBE to access tracked ids and raw ids (#3506)
Summary: This diff introduces RawIdTrackerWrapper, a wrapper class containing the lookup and delete APIs that are registered during raw_ids_tracker initialization to access tracked ids and raw_ids. We needed to create a wrapper instead of passing in the tracker itself because of circular dependency issues, since TBE is wrapped under DMP. Internal: This is needed to support MPZCH modules for raw embedding streaming. More details: https://docs.google.com/document/d/1KEHwiXKLgXwRIdDFBYopjX3OiP3mRLM24Qkbiiu-TgE/edit?tab=t.0#bookmark=id.lhhgee2cs6ld Reviewed By: chouxi Differential Revision: D84925177
1 parent 61eeb09 commit 8ce55cf

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from math import sqrt
1818
from typing import (
1919
Any,
20+
Callable,
2021
cast,
2122
Dict,
2223
Generic,
@@ -70,6 +71,7 @@
7071
GroupedEmbeddingConfig,
7172
ShardedEmbeddingTable,
7273
)
74+
from torchrec.distributed.model_tracker.types import IndexedLookup
7375
from torchrec.distributed.shards_wrapper import LocalShardsWrapper
7476
from torchrec.distributed.types import (
7577
Shard,
@@ -80,6 +82,7 @@
8082
TensorProperties,
8183
)
8284
from torchrec.distributed.utils import append_prefix, none_throws
85+
8386
from torchrec.modules.embedding_configs import (
8487
CountBasedEvictionPolicy,
8588
CountTimestampMixedEvictionPolicy,
@@ -97,13 +100,30 @@
97100
)
98101
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
99102

103+
100104
logger: logging.Logger = logging.getLogger(__name__)
101105

102106
RES_ENABLED_TABLES_STR = "res_enabled_tables"
103107
RES_STORE_SHARDS_STR = "res_store_shards"
104108
ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming"
105109

106110

111+
class RawIdTrackerWrapper:
    """Holder for the two callbacks used to reach tracked ids and raw ids.

    The callables are handed over instead of the tracker object itself, so
    the embedding kernel does not need a direct reference to the tracker
    (the commit summary cites circular dependency issues as the reason).

    Args:
        get_indexed_lookups: callable taking a list of strings and an
            optional string, returning a dict that maps a string to a list
            of tensors. NOTE(review): presumably the list holds table names
            and the optional string a consumer id -- confirm against the
            tracker that registers it.
        delete: callable taking a single int and returning nothing.
            NOTE(review): the meaning of the int is not visible here --
            confirm against the tracker.
    """

    def __init__(
        self,
        get_indexed_lookups: Callable[
            [List[str], Optional[str]], Dict[str, List[torch.Tensor]]
        ],
        delete: Callable[[int], None],
    ) -> None:
        # Both callbacks are stored as-is and exposed as public attributes;
        # no validation or wrapping happens here.
        self.delete = delete
        self.get_indexed_lookups = get_indexed_lookups
125+
126+
107127
def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]:
108128
# populate res_params, which is used for raw embedding streaming
109129
# here only populates the params available in fused_params and TBE configs
@@ -2526,6 +2546,7 @@ def __init__(
25262546
self._lengths_per_emb: List[int] = []
25272547
self.table_name_to_count: Dict[str, int] = {}
25282548
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
2549+
self._raw_id_tracker_wrapper: Optional[RawIdTrackerWrapper] = None
25292550

25302551
for idx, table_config in enumerate(self._config.embedding_tables):
25312552
self._local_rows.append(table_config.local_rows)
@@ -2668,6 +2689,22 @@ def named_parameters_by_table(
26682689
for name, param in self._param_per_table.items():
26692690
yield name, param
26702691

2692+
def init_raw_id_tracker(
    self,
    get_indexed_lookups: Callable[
        [List[str], Optional[str]], Dict[str, List[torch.Tensor]]
    ],
    delete: Callable[[int], None],
) -> None:
    """Register the raw-id tracker callbacks on this module.

    The two callables are bundled into a ``RawIdTrackerWrapper`` and stored
    on ``self._raw_id_tracker_wrapper``. This only happens when
    ``self._emb_module`` is a ``SplitTableBatchedEmbeddingBagsCodegen``;
    for any other embedding module the attribute is left untouched.

    Args:
        get_indexed_lookups: callable taking a list of strings and an
            optional string, returning a dict of string to list of tensors.
        delete: callable taking a single int and returning nothing.
    """
    # Nothing is registered for other kinds of embedding module.
    if not isinstance(self._emb_module, SplitTableBatchedEmbeddingBagsCodegen):
        return

    self._raw_id_tracker_wrapper = RawIdTrackerWrapper(
        get_indexed_lookups=get_indexed_lookups,
        delete=delete,
    )
2707+
26712708

26722709
class KeyValueEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule):
26732710
def __init__(

0 commit comments

Comments
 (0)