Skip to content

Commit dc9c7af

Browse files
maliafzalmeta-codesync[bot]
authored andcommitted
Post lookup tracker function for MCC modules (#3500)
Summary: Pull Request resolved: #3500 Adding post lookup tracker function within MMC module to allow tracking of hash_zch_identities with delta tracker. internal This is needed to support MPZCH modules for Raw embedding streaming. Mode details : https://docs.google.com/document/d/1KEHwiXKLgXwRIdDFBYopjX3OiP3mRLM24Qkbiiu-TgE/edit?tab=t.0#bookmark=id.lhhgee2cs6ld Differential Revision: D84920121
1 parent fc1f3ed commit dc9c7af

File tree

2 files changed

+57
-3
lines changed

2 files changed

+57
-3
lines changed

torchrec/distributed/mc_modules.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,17 @@
1313
import math
1414
from collections import defaultdict, OrderedDict
1515
from dataclasses import dataclass
16-
from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Type, Union
16+
from typing import (
17+
Any,
18+
Callable,
19+
DefaultDict,
20+
Dict,
21+
Iterator,
22+
List,
23+
Optional,
24+
Type,
25+
Union,
26+
)
1727

1828
import torch
1929
import torch.distributed as dist
@@ -58,6 +68,7 @@
5868
ShardingType,
5969
)
6070
from torchrec.distributed.utils import append_prefix
71+
from torchrec.modules.embedding_configs import BaseEmbeddingConfig
6172
from torchrec.modules.mc_modules import ManagedCollisionCollection
6273
from torchrec.modules.utils import construct_jagged_tensors
6374
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
@@ -215,6 +226,9 @@ def __init__(
215226

216227
self._feature_to_table: Dict[str, str] = module._feature_to_table
217228
self._table_to_features: Dict[str, List[str]] = module._table_to_features
229+
self._table_name_to_config: Dict[str, BaseEmbeddingConfig] = (
230+
module._table_name_to_config
231+
)
218232
self._has_uninitialized_input_dists: bool = True
219233
self._input_dists: List[nn.Module] = []
220234
self._managed_collision_modules = nn.ModuleDict()
@@ -223,6 +237,9 @@ def __init__(
223237
self._create_output_dists()
224238
self._use_index_dedup = use_index_dedup
225239
self._initialize_torch_state()
240+
self.post_lookup_tracker_fn: Optional[
241+
Callable[[KeyedJaggedTensor, torch.Tensor], None]
242+
] = None
226243

227244
def _initialize_torch_state(self) -> None:
228245
self._model_parallel_mc_buffer_name_to_sharded_tensor = OrderedDict()
@@ -732,6 +749,17 @@ def compute(
732749
mc_input = mcm.remap(mc_input)
733750
mc_input = self.global_to_local_index(mc_input)
734751
output.update(mc_input)
752+
if hasattr(
753+
mcm,
754+
"_hash_zch_identities",
755+
):
756+
if self.post_lookup_tracker_fn is not None:
757+
self.post_lookup_tracker_fn(
758+
KeyedJaggedTensor.from_jt_dict(mc_input),
759+
mcm._hash_zch_identities.index_select(
760+
dim=0, index=mc_input[table].values()
761+
),
762+
)
735763
values = torch.cat([jt.values() for jt in output.values()])
736764
else:
737765
table: str = tables[0]
@@ -750,6 +778,12 @@ def compute(
750778
mc_input = mcm.remap(mc_input)
751779
mc_input = self.global_to_local_index(mc_input)
752780
values = mc_input[table].values()
781+
if hasattr(mcm, "_hash_zch_identities"):
782+
if self.post_lookup_tracker_fn is not None:
783+
self.post_lookup_tracker_fn(
784+
KeyedJaggedTensor.from_jt_dict(mc_input),
785+
mcm._hash_zch_identities.index_select(dim=0, index=values),
786+
)
753787

754788
remapped_kjts.append(
755789
KeyedJaggedTensor(
@@ -840,6 +874,24 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
840874
def unsharded_module_type(self) -> Type[ManagedCollisionCollection]:
841875
return ManagedCollisionCollection
842876

877+
def register_post_lookup_tracker_fn(
878+
self,
879+
record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
880+
) -> None:
881+
"""
882+
Register a function to be called after lookup is done. This is used for
883+
tracking the lookup results and optimizer states.
884+
885+
Args:
886+
record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
887+
888+
"""
889+
if self.post_lookup_tracker_fn is not None:
890+
logger.warning(
891+
"[ModelDeltaTracker] Custom record function already defined, overriding with new callable"
892+
)
893+
self.post_lookup_tracker_fn = record_fn
894+
843895

844896
class ManagedCollisionCollectionSharder(
845897
BaseEmbeddingSharder[ManagedCollisionCollection]

torchrec/modules/mc_modules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,11 @@ def __init__(
357357
len(features) for features in self._table_to_features.values()
358358
]
359359

360-
table_to_config = {config.name: config for config in embedding_configs}
360+
self._table_name_to_config: Dict[str, BaseEmbeddingConfig] = {
361+
config.name: config for config in embedding_configs
362+
}
361363

362-
for name, config in table_to_config.items():
364+
for name, config in self._table_name_to_config.items():
363365
if name not in managed_collision_modules:
364366
raise ValueError(
365367
f"Table {name} is not present in managed_collision_modules"

0 commit comments

Comments
 (0)