Skip to content

Commit febf3f0

Browse files
aliafzalfacebook-github-bot
authored andcommitted
Enabling DMP to init raw_id_tracker if configured (#3502)
Summary: This diff introduces init_raw_id_tracker to initialize RawIdTracker if enabled through ModelTrackerConfig Key Changes ----------- 1. **Added `TrackerType` enum** in [`**types.py**`](command:code-compose.open?%5B%22%2Fdata%2Fusers%2Fmaliafzal%2Ffbsource%2Ffbcode%2Ftorchrec%2Fdistributed%2Fmodel_tracker%2Ftypes.py%22%2Cnull%5D "/data/users/maliafzal/fbsource/fbcode/torchrec/distributed/model_tracker/types.py") * Defines three tracker types: `NONE`, `TREC` (for EC/EBC), and `RAW_ID` (for MPZCH) * Extended `ModelTrackerConfig` dataclass with `tracker_type` field (defaults to `TrackerType.NONE`) 2. **Enhanced DMP initialization** in [`**model_parallel.py**`](command:code-compose.open?%5B%22%2Fdata%2Fusers%2Fmaliafzal%2Ffbsource%2Ffbcode%2Ftorchrec%2Fdistributed%2Fmodel_parallel.py%22%2Cnull%5D "/data/users/maliafzal/fbsource/fbcode/torchrec/distributed/model_parallel.py") * Added `init_raw_id_tracker()` method to create `RawIdTracker` instances * Modified constructor to conditionally initialize trackers based on `tracker_type` configuration internal This is needed to support MPZCH modules for Raw embedding streaming. Mode details : https://docs.google.com/document/d/1KEHwiXKLgXwRIdDFBYopjX3OiP3mRLM24Qkbiiu-TgE/edit?tab=t.0#bookmark=id.lhhgee2cs6ld Differential Revision: D84920233
1 parent c13c3fd commit febf3f0

File tree

2 files changed

+48
-6
lines changed

2 files changed

+48
-6
lines changed

torchrec/distributed/model_parallel.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,16 @@
2929
from torch.nn.modules.module import _IncompatibleKeys
3030
from torch.nn.parallel import DistributedDataParallel
3131
from torchrec.distributed.comm import get_local_size
32-
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec
33-
from torchrec.distributed.model_tracker.types import ModelTrackerConfig, UniqueRows
32+
from torchrec.distributed.model_tracker.model_delta_tracker import (
33+
ModelDeltaTracker,
34+
ModelDeltaTrackerTrec,
35+
)
36+
from torchrec.distributed.model_tracker.trackers.raw_id_tracker import RawIdTracker
37+
from torchrec.distributed.model_tracker.types import (
38+
ModelTrackerConfig,
39+
TrackerType,
40+
UniqueRows,
41+
)
3442

3543
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
3644
from torchrec.distributed.sharding_plan import get_default_sharders
@@ -53,6 +61,7 @@
5361
none_throws,
5462
sharded_model_copy,
5563
)
64+
5665
from torchrec.optim.fused import FusedOptimizerModule
5766
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
5867

@@ -294,12 +303,21 @@ def __init__(
294303
if init_data_parallel:
295304
self.init_data_parallel()
296305

297-
self.model_delta_tracker: Optional[ModelDeltaTrackerTrec] = (
306+
self.model_delta_tracker: Optional[ModelDeltaTracker] = (
298307
self._init_delta_tracker(model_tracker_config, self._dmp_wrapped_module)
299308
if model_tracker_config is not None
309+
and model_tracker_config.tracker_type == TrackerType.TREC
300310
else None
301311
)
302312

313+
if (
314+
model_tracker_config is not None
315+
and model_tracker_config.tracker_type == TrackerType.RAW_ID
316+
):
317+
self.raw_id_tracker: Optional[RawIdTracker] = self.init_raw_id_tracker(
318+
model_tracker_config, self._dmp_wrapped_module
319+
)
320+
303321
@property
304322
def module(self) -> nn.Module:
305323
"""
@@ -370,7 +388,7 @@ def _init_dmp(self, module: nn.Module) -> nn.Module:
370388

371389
def _init_delta_tracker(
372390
self, model_tracker_config: ModelTrackerConfig, module: nn.Module
373-
) -> ModelDeltaTrackerTrec:
391+
) -> ModelDeltaTracker:
374392
# Init delta tracker if config is provided
375393
return ModelDeltaTrackerTrec(
376394
model=module,
@@ -381,6 +399,15 @@ def _init_delta_tracker(
381399
fqns_to_skip=model_tracker_config.fqns_to_skip,
382400
)
383401

402+
def init_raw_id_tracker(
403+
self, model_tracker_config: ModelTrackerConfig, module: nn.Module
404+
) -> Optional[RawIdTracker]:
405+
return RawIdTracker(
406+
model=module,
407+
delete_on_read=model_tracker_config.delete_on_read,
408+
fqns_to_skip=model_tracker_config.fqns_to_skip,
409+
)
410+
384411
def _init_optim(self, module: nn.Module) -> CombinedOptimizer:
385412
# pyre-ignore [6]
386413
return CombinedOptimizer(self._fused_optim_impl(module, []))
@@ -459,7 +486,7 @@ def init_parameters(module: nn.Module) -> None:
459486

460487
def init_torchrec_delta_tracker(
461488
self, model_tracker_config: ModelTrackerConfig
462-
) -> ModelDeltaTrackerTrec:
489+
) -> ModelDeltaTracker:
463490
"""
464491
Initializes the model delta tracker if it doesn't exists.
465492
"""
@@ -470,7 +497,7 @@ def init_torchrec_delta_tracker(
470497

471498
return none_throws(self.model_delta_tracker)
472499

473-
def get_model_tracker(self) -> ModelDeltaTrackerTrec:
500+
def get_model_tracker(self) -> ModelDeltaTracker:
474501
"""
475502
Returns the model tracker if it exists.
476503
"""

torchrec/distributed/model_tracker/types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,20 @@ class UpdateMode(Enum):
7575
LAST = "last"
7676

7777

78+
class TrackerType(Enum):
79+
r"""
80+
To identify the type of tracker.
81+
82+
Enums:
83+
TREC: Used for Trec EC/EBC.
84+
RAW_ID: Used for MPZCH.
85+
"""
86+
87+
NONE = "none"
88+
TREC = "trec"
89+
RAW_ID = "raw_id"
90+
91+
7892
@dataclass
7993
class ModelTrackerConfig:
8094
r"""
@@ -92,3 +106,4 @@ class ModelTrackerConfig:
92106
delete_on_read: bool = True
93107
auto_compact: bool = False
94108
fqns_to_skip: List[str] = field(default_factory=list)
109+
tracker_type: TrackerType = TrackerType.NONE

0 commit comments

Comments
 (0)