1717from math import sqrt
1818from typing import (
1919 Any ,
20+ Callable ,
2021 cast ,
2122 Dict ,
2223 Generic ,
7071 GroupedEmbeddingConfig ,
7172 ShardedEmbeddingTable ,
7273)
74+ from torchrec .distributed .model_tracker .types import IndexedLookup
7375from torchrec .distributed .shards_wrapper import LocalShardsWrapper
7476from torchrec .distributed .types import (
7577 Shard ,
8082 TensorProperties ,
8183)
8284from torchrec .distributed .utils import append_prefix , none_throws
85+
8386from torchrec .modules .embedding_configs import (
8487 CountBasedEvictionPolicy ,
8588 CountTimestampMixedEvictionPolicy ,
97100)
98101from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
99102
103+
100104logger : logging .Logger = logging .getLogger (__name__ )
101105
102106RES_ENABLED_TABLES_STR = "res_enabled_tables"
103107RES_STORE_SHARDS_STR = "res_store_shards"
104108ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming"
105109
106110
class RawIdTrackerWrapper:
    """Bundle the two raw-id tracker callbacks passed to embedding kernels.

    Groups the tracker's ``get_indexed_lookups`` and ``delete`` callables
    into a single object so they can be stored and handed around together
    (e.g. kept on a batched-embedding module as ``_raw_id_tracker_wrapper``).
    """

    def __init__(
        self,
        get_indexed_lookups: Callable[
            [List[str], Optional[str]],
            List[torch.Tensor],
        ],
        delete: Callable[[int], None],
    ) -> None:
        # Callback returning tracked raw-id tensors for the given table/feature
        # names; the optional string looks like a module identifier (the TBE
        # uuid at the call site) -- TODO(review): confirm against the tracker.
        self.get_indexed_lookups = get_indexed_lookups
        # Callback evicting tracked entries; the meaning of the int argument
        # (presumably a batch/iteration index) is not visible here -- verify.
        self.delete = delete
126+
107127def _populate_res_params (config : GroupedEmbeddingConfig ) -> Tuple [bool , RESParams ]:
108128 # populate res_params, which is used for raw embedding streaming
109129 # here only populates the params available in fused_params and TBE configs
@@ -2526,6 +2546,7 @@ def __init__(
25262546 self ._lengths_per_emb : List [int ] = []
25272547 self .table_name_to_count : Dict [str , int ] = {}
25282548 self ._param_per_table : Dict [str , TableBatchedEmbeddingSlice ] = {}
2549+ self ._raw_id_tracker_wrapper : Optional [RawIdTrackerWrapper ] = None
25292550
25302551 for idx , table_config in enumerate (self ._config .embedding_tables ):
25312552 self ._local_rows .append (table_config .local_rows )
@@ -2579,7 +2600,26 @@ def init_parameters(self) -> None:
25792600 weight_init_max ,
25802601 )
25812602
2582- def forward (self , features : KeyedJaggedTensor ) -> torch .Tensor :
2603+ def forward (
2604+ self ,
2605+ features : KeyedJaggedTensor ,
2606+ ) -> torch .Tensor :
2607+ forward_args : Dict [str , Any ] = {}
2608+ if self ._raw_id_tracker_wrapper is not None :
2609+ if isinstance (self .emb_module , SplitTableBatchedEmbeddingBagsCodegen ):
2610+ raw_id_tracker_wrapper = self ._raw_id_tracker_wrapper
2611+ assert (
2612+ raw_id_tracker_wrapper is not None
2613+ ), "self._raw_id_tracker_wrapper should not be None"
2614+ # TODO: Calling get_indexed_lookups(None) retrieves raw IDs for ALL tracked FQNs,
2615+ # including those this TBE doesn't own, and advances the shared consumer read index.
2616+ # While storage isn't deleted, advancing the index prevents re-reading, which blocks
2617+ # other TBEs from accessing their tracked raw IDs.
2618+ raw_ids_list = raw_id_tracker_wrapper .get_indexed_lookups (
2619+ features .keys (), self .emb_module .uuid
2620+ )
2621+ if raw_ids_list :
2622+ forward_args ["hash_zch_identities" ] = torch .cat (raw_ids_list )
25832623 weights = features .weights_or_none ()
25842624 if weights is not None and not torch .is_floating_point (weights ):
25852625 weights = None
@@ -2591,17 +2631,22 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
25912631 SSDTableBatchedEmbeddingBags ,
25922632 ),
25932633 ):
2634+ forward_args ["batch_size_per_feature_per_rank" ] = (
2635+ features .stride_per_key_per_rank ()
2636+ )
2637+
2638+ if len (forward_args ) == 0 :
25942639 return self .emb_module (
25952640 indices = features .values ().long (),
25962641 offsets = features .offsets ().long (),
25972642 per_sample_weights = weights ,
2598- batch_size_per_feature_per_rank = features .stride_per_key_per_rank (),
25992643 )
26002644 else :
26012645 return self .emb_module (
26022646 indices = features .values ().long (),
26032647 offsets = features .offsets ().long (),
26042648 per_sample_weights = weights ,
2649+ ** forward_args ,
26052650 )
26062651
26072652 # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -2668,6 +2713,22 @@ def named_parameters_by_table(
26682713 for name , param in self ._param_per_table .items ():
26692714 yield name , param
26702715
2716+ def init_raw_id_tracker (
2717+ self ,
2718+ get_indexed_lookups : Callable [
2719+ [List [str ], Optional [str ]],
2720+ List [torch .Tensor ],
2721+ ],
2722+ delete : Callable [
2723+ [int ],
2724+ None ,
2725+ ],
2726+ ) -> None :
2727+ if isinstance (self ._emb_module , SplitTableBatchedEmbeddingBagsCodegen ):
2728+ self ._raw_id_tracker_wrapper = RawIdTrackerWrapper (
2729+ get_indexed_lookups , delete
2730+ )
2731+
26712732
26722733class KeyValueEmbeddingBag (BaseBatchedEmbeddingBag [torch .Tensor ], FusedOptimizerModule ):
26732734 def __init__ (
0 commit comments