Skip to content

Commit 7fff1b3

Browse files
aliafzalmeta-codesync[bot]
authored and committed
Enabling DMP to init raw_id_tracker if configured (#3502)
Summary: Pull Request resolved: #3502 This diff introduces init_raw_id_tracker to initialize RawIdTracker if enabled through ModelTrackerConfig Key Changes ----------- 1. **Added `TrackerType` enum** in [`**types.py**`](command:code-compose.open?%5B%22%2Fdata%2Fusers%2Fmaliafzal%2Ffbsource%2Ffbcode%2Ftorchrec%2Fdistributed%2Fmodel_tracker%2Ftypes.py%22%2Cnull%5D "/data/users/maliafzal/fbsource/fbcode/torchrec/distributed/model_tracker/types.py") * Defines three tracker types: `NONE`, `TREC` (for EC/EBC), and `RAW_ID` (for MPZCH) * Extended `ModelTrackerConfig` dataclass with `tracker_type` field (defaults to `TrackerType.NONE`) 2. **Enhanced DMP initialization** in [`**model_parallel.py**`](command:code-compose.open?%5B%22%2Fdata%2Fusers%2Fmaliafzal%2Ffbsource%2Ffbcode%2Ftorchrec%2Fdistributed%2Fmodel_parallel.py%22%2Cnull%5D "/data/users/maliafzal/fbsource/fbcode/torchrec/distributed/model_parallel.py") * Added `init_raw_id_tracker()` method to create `RawIdTracker` instances * Modified constructor to conditionally initialize trackers based on `tracker_type` configuration internal This is needed to support MPZCH modules for Raw embedding streaming. More details: https://docs.google.com/document/d/1KEHwiXKLgXwRIdDFBYopjX3OiP3mRLM24Qkbiiu-TgE/edit?tab=t.0#bookmark=id.lhhgee2cs6ld Reviewed By: chouxi Differential Revision: D84920233 fbshipit-source-id: 11fa8826d8166d8b2f4027bc3fa0081e4428fcb0
1 parent 8f0dab7 commit 7fff1b3

File tree

4 files changed

+161
-113
lines changed

4 files changed

+161
-113
lines changed

torchrec/distributed/model_parallel.py

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,18 @@
2929
from torch.nn.modules.module import _IncompatibleKeys
3030
from torch.nn.parallel import DistributedDataParallel
3131
from torchrec.distributed.comm import get_local_size
32-
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec
33-
from torchrec.distributed.model_tracker.types import ModelTrackerConfig, UniqueRows
32+
from torchrec.distributed.model_tracker.model_delta_tracker import (
33+
ModelDeltaTracker,
34+
ModelDeltaTrackerTrec,
35+
)
36+
from torchrec.distributed.model_tracker.trackers.raw_id_tracker import RawIdTracker
37+
from torchrec.distributed.model_tracker.types import (
38+
DeltaTrackerConfig,
39+
ModelTrackerConfigs,
40+
RawIdTrackerConfig,
41+
Trackers,
42+
UniqueRows,
43+
)
3444

3545
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
3646
from torchrec.distributed.sharding_plan import get_default_sharders
@@ -53,6 +63,7 @@
5363
none_throws,
5464
sharded_model_copy,
5565
)
66+
5667
from torchrec.optim.fused import FusedOptimizerModule
5768
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
5869

@@ -240,7 +251,7 @@ def __init__(
240251
init_data_parallel: bool = True,
241252
init_parameters: bool = True,
242253
data_parallel_wrapper: Optional[DataParallelWrapper] = None,
243-
model_tracker_config: Optional[ModelTrackerConfig] = None,
254+
model_tracker_configs: Optional[ModelTrackerConfigs] = None,
244255
) -> None:
245256
super().__init__()
246257
torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
@@ -294,11 +305,18 @@ def __init__(
294305
if init_data_parallel:
295306
self.init_data_parallel()
296307

297-
self.model_delta_tracker: Optional[ModelDeltaTrackerTrec] = (
298-
self._init_delta_tracker(model_tracker_config, self._dmp_wrapped_module)
299-
if model_tracker_config is not None
300-
else None
301-
)
308+
self.model_trackers: Dict[str, ModelDeltaTracker] = {}
309+
310+
if (
311+
model_tracker_configs is not None
312+
and model_tracker_configs.raw_id_tracker_config is not None
313+
):
314+
self.model_trackers[Trackers.RAW_ID_TRACKER.name] = (
315+
self._init_raw_id_tracker(
316+
model_tracker_configs.raw_id_tracker_config,
317+
self._dmp_wrapped_module,
318+
)
319+
)
302320

303321
@property
304322
def module(self) -> nn.Module:
@@ -321,14 +339,14 @@ def module(self, value: nn.Module) -> None:
321339

322340
# pyre-ignore [2, 3]
323341
def forward(self, *args, **kwargs) -> Any:
324-
if self.model_delta_tracker is not None:
342+
for tracker in self.model_trackers.values():
325343
# The step() call advances the internal batch counter so that subsequent ID tracking and delta
326344
# retrieval operations can be properly organized by batch boundaries.
327345

328346
# Context: ModelDeltaTracker tracks unique embedding IDs (and optionally embeddings / states)
329347
# useful for calculating topk rows for checkpointing and for updating fresh embedding weights
330348
# between predictors and trainers in online training scenarios.
331-
self.model_delta_tracker.step()
349+
tracker.step()
332350

333351
# Model forward for DMP wrapped model
334352
return self._dmp_wrapped_module(*args, **kwargs)
@@ -369,16 +387,25 @@ def _init_dmp(self, module: nn.Module) -> nn.Module:
369387
return self._shard_modules_impl(module)
370388

371389
def _init_delta_tracker(
372-
self, model_tracker_config: ModelTrackerConfig, module: nn.Module
373-
) -> ModelDeltaTrackerTrec:
390+
self, delta_tracker_config: DeltaTrackerConfig, module: nn.Module
391+
) -> ModelDeltaTracker:
374392
# Init delta tracker if config is provided
375393
return ModelDeltaTrackerTrec(
376394
model=module,
377-
consumers=model_tracker_config.consumers,
378-
delete_on_read=model_tracker_config.delete_on_read,
379-
auto_compact=model_tracker_config.auto_compact,
380-
mode=model_tracker_config.tracking_mode,
381-
fqns_to_skip=model_tracker_config.fqns_to_skip,
395+
consumers=delta_tracker_config.consumers,
396+
delete_on_read=delta_tracker_config.delete_on_read,
397+
auto_compact=delta_tracker_config.auto_compact,
398+
mode=delta_tracker_config.tracking_mode,
399+
fqns_to_skip=delta_tracker_config.fqns_to_skip,
400+
)
401+
402+
def _init_raw_id_tracker(
403+
self, raw_id_tracker_config: RawIdTrackerConfig, module: nn.Module
404+
) -> RawIdTracker:
405+
return RawIdTracker(
406+
model=module,
407+
delete_on_read=raw_id_tracker_config.delete_on_read,
408+
fqns_to_skip=raw_id_tracker_config.fqns_to_skip,
382409
)
383410

384411
def _init_optim(self, module: nn.Module) -> CombinedOptimizer:
@@ -458,36 +485,25 @@ def init_parameters(module: nn.Module) -> None:
458485
module.apply(init_parameters)
459486

460487
def init_torchrec_delta_tracker(
461-
self, model_tracker_config: ModelTrackerConfig
462-
) -> ModelDeltaTrackerTrec:
488+
self, delta_tracker_config: DeltaTrackerConfig
489+
) -> ModelDeltaTracker:
463490
"""
464491
Initializes the model delta tracker if it doesn't exists.
465492
"""
466-
if self.model_delta_tracker is None:
467-
self.model_delta_tracker = self._init_delta_tracker(
468-
model_tracker_config, self._dmp_wrapped_module
493+
if Trackers.DELTA_TRACKER.name not in self.model_trackers:
494+
self.model_trackers[Trackers.DELTA_TRACKER.name] = self._init_delta_tracker(
495+
delta_tracker_config, self._dmp_wrapped_module
469496
)
470497

471-
return none_throws(self.model_delta_tracker)
472-
473-
def get_model_tracker(self) -> ModelDeltaTrackerTrec:
474-
"""
475-
Returns the model tracker if it exists.
476-
"""
477-
478-
assert (
479-
self.model_delta_tracker is not None
480-
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
481-
return self.model_delta_tracker
498+
return self.model_trackers[Trackers.DELTA_TRACKER.name]
482499

483-
def get_unique(self, consumer: Optional[str] = None) -> Dict[str, UniqueRows]:
500+
def get_delta_tracker(self) -> Optional[ModelDeltaTracker]:
484501
"""
485-
Returns the delta rows for the given consumer.
502+
Returns the delta tracker if it exists.
486503
"""
487-
assert (
488-
self.model_delta_tracker is not None
489-
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
490-
return self.model_delta_tracker.get_unique(consumer)
504+
if Trackers.DELTA_TRACKER.name in self.model_trackers:
505+
return self.model_trackers[Trackers.DELTA_TRACKER.name]
506+
return None
491507

492508
def sparse_grad_parameter_names(
493509
self, destination: Optional[List[str]] = None, prefix: str = ""

torchrec/distributed/model_tracker/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
)
2929
from torchrec.distributed.model_tracker.types import (
3030
IndexedLookup, # noqa
31-
ModelTrackerConfig, # noqa
31+
ModelTrackerConfigs, # noqa
32+
Trackers, # noqa
3233
TrackingMode, # noqa
3334
UniqueRows, # noqa
3435
UpdateMode, # noqa

0 commit comments

Comments
 (0)