Skip to content

Commit d6dbdad

Browse files
aliafzal and facebook-github-bot
authored and committed
Add raw_id_tracker for tracking hash_zch_identities of MPZCH module (#3501)
Summary: This diff introduces a new `RawIdTracker` class that extends TorchRec's model delta tracking infra to capture and track raw hash identities from MCC modules during training. This is specifically required for tracking raw ids for MPZCH tables. Internal: This is needed to support MPZCH modules for raw embedding streaming. More details: https://docs.google.com/document/d/1KEHwiXKLgXwRIdDFBYopjX3OiP3mRLM24Qkbiiu-TgE/edit?tab=t.0#bookmark=id.lhhgee2cs6ld Differential Revision: D84920167
1 parent c135256 commit d6dbdad

File tree

8 files changed

+367
-7
lines changed

8 files changed

+367
-7
lines changed

torchrec/distributed/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1605,7 +1605,7 @@ def compute_and_output_dist(
16051605
):
16061606
embs = lookup(features)
16071607
if self.post_lookup_tracker_fn is not None:
1608-
self.post_lookup_tracker_fn(features, embs, self)
1608+
self.post_lookup_tracker_fn(features, embs, self, None)
16091609

16101610
with maybe_annotate_embedding_event(
16111611
EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type

torchrec/distributed/embedding_types.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,15 @@ def __init__(
391391
self._lookups: List[nn.Module] = []
392392
self._output_dists: List[nn.Module] = []
393393
self.post_lookup_tracker_fn: Optional[
394-
Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
394+
Callable[
395+
[
396+
KeyedJaggedTensor,
397+
torch.Tensor,
398+
Optional[nn.Module],
399+
Optional[torch.Tensor],
400+
],
401+
None,
402+
]
395403
] = None
396404
self.post_odist_tracker_fn: Optional[Callable[..., None]] = None
397405

@@ -445,7 +453,13 @@ def train(self, mode: bool = True): # pyre-ignore[3]
445453
def register_post_lookup_tracker_fn(
446454
self,
447455
record_fn: Callable[
448-
[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
456+
[
457+
KeyedJaggedTensor,
458+
torch.Tensor,
459+
Optional[nn.Module],
460+
Optional[torch.Tensor],
461+
],
462+
None,
449463
],
450464
) -> None:
451465
"""

torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1671,7 +1671,7 @@ def compute_and_output_dist(
16711671
):
16721672
embs = lookup(features)
16731673
if self.post_lookup_tracker_fn is not None:
1674-
self.post_lookup_tracker_fn(features, embs, self)
1674+
self.post_lookup_tracker_fn(features, embs, self, None)
16751675

16761676
with maybe_annotate_embedding_event(
16771677
EmbeddingEvent.OUTPUT_DIST,

torchrec/distributed/mc_modules.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,15 @@ def __init__(
238238
self._use_index_dedup = use_index_dedup
239239
self._initialize_torch_state()
240240
self.post_lookup_tracker_fn: Optional[
241-
Callable[[KeyedJaggedTensor, torch.Tensor], None]
241+
Callable[
242+
[
243+
KeyedJaggedTensor,
244+
torch.Tensor,
245+
Optional[nn.Module],
246+
Optional[torch.Tensor],
247+
],
248+
None,
249+
]
242250
] = None
243251

244252
def _initialize_torch_state(self) -> None:
@@ -756,6 +764,8 @@ def compute(
756764
if self.post_lookup_tracker_fn is not None:
757765
self.post_lookup_tracker_fn(
758766
KeyedJaggedTensor.from_jt_dict(output),
767+
torch.empty(0),
768+
None,
759769
mcm._hash_zch_identities.index_select(
760770
dim=0, index=mc_input[table].values()
761771
),
@@ -782,6 +792,8 @@ def compute(
782792
if self.post_lookup_tracker_fn is not None:
783793
self.post_lookup_tracker_fn(
784794
KeyedJaggedTensor.from_jt_dict(mc_input),
795+
torch.empty(0),
796+
None,
785797
mcm._hash_zch_identities.index_select(dim=0, index=values),
786798
)
787799

@@ -876,7 +888,15 @@ def unsharded_module_type(self) -> Type[ManagedCollisionCollection]:
876888

877889
def register_post_lookup_tracker_fn(
878890
self,
879-
record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
891+
record_fn: Callable[
892+
[
893+
KeyedJaggedTensor,
894+
torch.Tensor,
895+
Optional[nn.Module],
896+
Optional[torch.Tensor],
897+
],
898+
None,
899+
],
880900
) -> None:
881901
"""
882902
Register a function to be called after lookup is done. This is used for

torchrec/distributed/model_tracker/delta_store.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def append(
9191
fqn: str,
9292
ids: torch.Tensor,
9393
states: Optional[torch.Tensor],
94+
raw_ids: Optional[torch.Tensor] = None,
9495
) -> None:
9596
"""
9697
Append a batch of ids and states to the store for a specific table.
@@ -162,10 +163,11 @@ def append(
162163
fqn: str,
163164
ids: torch.Tensor,
164165
states: Optional[torch.Tensor],
166+
raw_ids: Optional[torch.Tensor] = None,
165167
) -> None:
166168
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
167169
table_fqn_lookup.append(
168-
IndexedLookup(batch_idx=batch_idx, ids=ids, states=states)
170+
IndexedLookup(batch_idx=batch_idx, ids=ids, states=states, raw_ids=raw_ids)
169171
)
170172
self.per_fqn_lookups[fqn] = table_fqn_lookup
171173

@@ -224,6 +226,20 @@ def compact(self, start_idx: int, end_idx: int) -> None:
224226
)
225227
self.per_fqn_lookups = new_per_fqn_lookups
226228

229+
def get_indexed_lookups(
230+
self, start_idx: int, end_idx: int
231+
) -> Dict[str, List[IndexedLookup]]:
232+
r"""
233+
Return, per table, the recorded lookups whose batch index lies in [start_idx, end_idx); assumes each table's lookups are ordered by batch index.
234+
"""
235+
per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}
236+
for table_fqn, lookups in self.per_fqn_lookups.items():
237+
indexices = [h.batch_idx for h in lookups]
238+
index_l = bisect_left(indexices, start_idx)
239+
index_r = bisect_left(indexices, end_idx)
240+
per_fqn_lookups[table_fqn] = lookups[index_l:index_r]
241+
return per_fqn_lookups
242+
227243
def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]:
228244
r"""
229245
Return all unique/delta ids per table from the Delta Store.

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def record_lookup(
8484
kjt: KeyedJaggedTensor,
8585
states: torch.Tensor,
8686
emb_module: Optional[nn.Module] = None,
87+
raw_ids: Optional[torch.Tensor] = None,
8788
) -> None:
8889
"""
8990
Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
@@ -131,6 +132,13 @@ def clear(self, consumer: Optional[str] = None) -> None:
131132
"""
132133
pass
133134

135+
@abstractmethod
136+
def step(self) -> None:
137+
"""
138+
Advance the batch index for all consumers.
139+
"""
140+
pass
141+
134142

135143
class ModelDeltaTrackerTrec(ModelDeltaTracker):
136144
r"""
@@ -244,6 +252,7 @@ def record_lookup(
244252
kjt: KeyedJaggedTensor,
245253
states: torch.Tensor,
246254
emb_module: Optional[nn.Module] = None,
255+
raw_ids: Optional[torch.Tensor] = None,
247256
) -> None:
248257
"""
249258
Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.

0 commit comments

Comments
 (0)