1313import math
1414from collections import defaultdict , OrderedDict
1515from dataclasses import dataclass
16- from typing import Any , DefaultDict , Dict , Iterator , List , Optional , Type , Union
16+ from typing import (
17+ Any ,
18+ Callable ,
19+ DefaultDict ,
20+ Dict ,
21+ Iterator ,
22+ List ,
23+ Optional ,
24+ Type ,
25+ Union ,
26+ )
1727
1828import torch
1929import torch .distributed as dist
5868 ShardingType ,
5969)
6070from torchrec .distributed .utils import append_prefix
71+ from torchrec .modules .embedding_configs import BaseEmbeddingConfig
6172from torchrec .modules .mc_modules import ManagedCollisionCollection
6273from torchrec .modules .utils import construct_jagged_tensors
6374from torchrec .sparse .jagged_tensor import JaggedTensor , KeyedJaggedTensor
@@ -215,6 +226,9 @@ def __init__(
215226
216227 self ._feature_to_table : Dict [str , str ] = module ._feature_to_table
217228 self ._table_to_features : Dict [str , List [str ]] = module ._table_to_features
229+ self ._table_name_to_config : Dict [str , BaseEmbeddingConfig ] = (
230+ module ._table_name_to_config
231+ )
218232 self ._has_uninitialized_input_dists : bool = True
219233 self ._input_dists : List [nn .Module ] = []
220234 self ._managed_collision_modules = nn .ModuleDict ()
@@ -223,6 +237,9 @@ def __init__(
223237 self ._create_output_dists ()
224238 self ._use_index_dedup = use_index_dedup
225239 self ._initialize_torch_state ()
240+ self .post_lookup_tracker_fn : Optional [
241+ Callable [[KeyedJaggedTensor , torch .Tensor ], None ]
242+ ] = None
226243
227244 def _initialize_torch_state (self ) -> None :
228245 self ._model_parallel_mc_buffer_name_to_sharded_tensor = OrderedDict ()
@@ -732,6 +749,17 @@ def compute(
732749 mc_input = mcm .remap (mc_input )
733750 mc_input = self .global_to_local_index (mc_input )
734751 output .update (mc_input )
752+ if hasattr (
753+ mcm ,
754+ "_hash_zch_identities" ,
755+ ):
756+ if self .post_lookup_tracker_fn is not None :
757+ self .post_lookup_tracker_fn (
758+ KeyedJaggedTensor .from_jt_dict (mc_input ),
759+ mcm ._hash_zch_identities .index_select (
760+ dim = 0 , index = mc_input [table ].values ()
761+ ),
762+ )
735763 values = torch .cat ([jt .values () for jt in output .values ()])
736764 else :
737765 table : str = tables [0 ]
@@ -750,6 +778,12 @@ def compute(
750778 mc_input = mcm .remap (mc_input )
751779 mc_input = self .global_to_local_index (mc_input )
752780 values = mc_input [table ].values ()
781+ if hasattr (mcm , "_hash_zch_identities" ):
782+ if self .post_lookup_tracker_fn is not None :
783+ self .post_lookup_tracker_fn (
784+ KeyedJaggedTensor .from_jt_dict (mc_input ),
785+ mcm ._hash_zch_identities .index_select (dim = 0 , index = values ),
786+ )
753787
754788 remapped_kjts .append (
755789 KeyedJaggedTensor (
@@ -840,6 +874,24 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
840874 def unsharded_module_type (self ) -> Type [ManagedCollisionCollection ]:
841875 return ManagedCollisionCollection
842876
877+ def register_post_lookup_tracker_fn (
878+ self ,
879+ record_fn : Callable [[KeyedJaggedTensor , torch .Tensor ], None ],
880+ ) -> None :
881+ """
882+ Register a function to be called after lookup is done. This is used for
883+ tracking the lookup results and optimizer states.
884+
885+ Args:
886+ record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
887+
888+ """
889+ if self .post_lookup_tracker_fn is not None :
890+ logger .warning (
891+ "[ModelDeltaTracker] Custom record function already defined, overriding with new callable"
892+ )
893+ self .post_lookup_tracker_fn = record_fn
894+
843895
844896class ManagedCollisionCollectionSharder (
845897 BaseEmbeddingSharder [ManagedCollisionCollection ]
0 commit comments