Skip to content

Commit cce57ee

Browse files
aliafzalfacebook-github-bot
authored andcommitted
Enabling DMP to init raw_id_tracker if configured (#3502)
Summary: This diff introduces init_raw_id_tracker to initialize RawIdTracker if enabled through ModelTrackerConfig Key Changes ----------- 1. **Added `TrackerType` enum** in [`**types.py**`](command:code-compose.open?%5B%22%2Fdata%2Fusers%2Fmaliafzal%2Ffbsource%2Ffbcode%2Ftorchrec%2Fdistributed%2Fmodel_tracker%2Ftypes.py%22%2Cnull%5D "/data/users/maliafzal/fbsource/fbcode/torchrec/distributed/model_tracker/types.py") * Defines three tracker types: `NONE`, `TREC` (for EC/EBC), and `RAW_ID` (for MPZCH) * Extended `ModelTrackerConfig` dataclass with `tracker_type` field (defaults to `TrackerType.NONE`) 2. **Enhanced DMP initialization** in [`**model_parallel.py**`](command:code-compose.open?%5B%22%2Fdata%2Fusers%2Fmaliafzal%2Ffbsource%2Ffbcode%2Ftorchrec%2Fdistributed%2Fmodel_parallel.py%22%2Cnull%5D "/data/users/maliafzal/fbsource/fbcode/torchrec/distributed/model_parallel.py") * Added `init_raw_id_tracker()` method to create `RawIdTracker` instances * Modified constructor to conditionally initialize trackers based on `tracker_type` configuration internal This is needed to support MPZCH modules for Raw embedding streaming. Mode details : https://docs.google.com/document/d/1KEHwiXKLgXwRIdDFBYopjX3OiP3mRLM24Qkbiiu-TgE/edit?tab=t.0#bookmark=id.lhhgee2cs6ld Reviewed By: chouxi Differential Revision: D84920233
1 parent 1056a70 commit cce57ee

File tree

3 files changed

+101
-15
lines changed

3 files changed

+101
-15
lines changed

torchrec/distributed/model_parallel.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,16 @@
2929
from torch.nn.modules.module import _IncompatibleKeys
3030
from torch.nn.parallel import DistributedDataParallel
3131
from torchrec.distributed.comm import get_local_size
32-
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec
33-
from torchrec.distributed.model_tracker.types import ModelTrackerConfig, UniqueRows
32+
from torchrec.distributed.model_tracker.model_delta_tracker import (
33+
ModelDeltaTracker,
34+
ModelDeltaTrackerTrec,
35+
)
36+
from torchrec.distributed.model_tracker.trackers.raw_id_tracker import RawIdTracker
37+
from torchrec.distributed.model_tracker.types import (
38+
ModelTrackerConfig,
39+
TrackerType,
40+
UniqueRows,
41+
)
3442

3543
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
3644
from torchrec.distributed.sharding_plan import get_default_sharders
@@ -53,6 +61,7 @@
5361
none_throws,
5462
sharded_model_copy,
5563
)
64+
5665
from torchrec.optim.fused import FusedOptimizerModule
5766
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
5867

@@ -294,12 +303,21 @@ def __init__(
294303
if init_data_parallel:
295304
self.init_data_parallel()
296305

297-
self.model_delta_tracker: Optional[ModelDeltaTrackerTrec] = (
306+
self.model_delta_tracker: Optional[ModelDeltaTracker] = (
298307
self._init_delta_tracker(model_tracker_config, self._dmp_wrapped_module)
299308
if model_tracker_config is not None
309+
and model_tracker_config.tracker_type == TrackerType.TREC
300310
else None
301311
)
302312

313+
if (
314+
model_tracker_config is not None
315+
and model_tracker_config.tracker_type == TrackerType.RAW_ID
316+
):
317+
self.raw_id_tracker: Optional[RawIdTracker] = self.init_raw_id_tracker(
318+
model_tracker_config, self._dmp_wrapped_module
319+
)
320+
303321
@property
304322
def module(self) -> nn.Module:
305323
"""
@@ -370,7 +388,7 @@ def _init_dmp(self, module: nn.Module) -> nn.Module:
370388

371389
def _init_delta_tracker(
372390
self, model_tracker_config: ModelTrackerConfig, module: nn.Module
373-
) -> ModelDeltaTrackerTrec:
391+
) -> ModelDeltaTracker:
374392
# Init delta tracker if config is provided
375393
return ModelDeltaTrackerTrec(
376394
model=module,
@@ -381,6 +399,15 @@ def _init_delta_tracker(
381399
fqns_to_skip=model_tracker_config.fqns_to_skip,
382400
)
383401

402+
def init_raw_id_tracker(
403+
self, model_tracker_config: ModelTrackerConfig, module: nn.Module
404+
) -> Optional[RawIdTracker]:
405+
return RawIdTracker(
406+
model=module,
407+
delete_on_read=model_tracker_config.delete_on_read,
408+
fqns_to_skip=model_tracker_config.fqns_to_skip,
409+
)
410+
384411
def _init_optim(self, module: nn.Module) -> CombinedOptimizer:
385412
# pyre-ignore [6]
386413
return CombinedOptimizer(self._fused_optim_impl(module, []))
@@ -459,7 +486,7 @@ def init_parameters(module: nn.Module) -> None:
459486

460487
def init_torchrec_delta_tracker(
461488
self, model_tracker_config: ModelTrackerConfig
462-
) -> ModelDeltaTrackerTrec:
489+
) -> ModelDeltaTracker:
463490
"""
464491
Initializes the model delta tracker if it doesn't exists.
465492
"""
@@ -470,7 +497,7 @@ def init_torchrec_delta_tracker(
470497

471498
return none_throws(self.model_delta_tracker)
472499

473-
def get_model_tracker(self) -> ModelDeltaTrackerTrec:
500+
def get_model_tracker(self) -> ModelDeltaTracker:
474501
"""
475502
Returns the model tracker if it exists.
476503
"""

torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,18 @@
2424
from torchrec.distributed.embedding import EmbeddingCollectionSharder
2525
from torchrec.distributed.embedding_types import ModuleSharder, ShardingType
2626
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
27+
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec
2728
from torchrec.distributed.model_tracker.tests.utils import (
2829
EmbeddingTableProps,
2930
generate_planner_constraints,
3031
TestEBCModel,
3132
TestECModel,
3233
)
33-
from torchrec.distributed.model_tracker.types import ModelTrackerConfig, TrackingMode
34+
from torchrec.distributed.model_tracker.types import (
35+
ModelTrackerConfig,
36+
TrackerType,
37+
TrackingMode,
38+
)
3439

3540
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
3641
from torchrec.distributed.test_utils.multi_process import (
@@ -228,7 +233,9 @@ def __init__(self, methodName="runTest") -> None:
228233
sharding=ShardingType.ROW_WISE,
229234
),
230235
],
231-
model_tracker_config=ModelTrackerConfig(),
236+
model_tracker_config=ModelTrackerConfig(
237+
tracker_type=TrackerType.TREC
238+
),
232239
),
233240
FqnToFeatureNamesOutputTestParams(
234241
expected_fqn_to_feature_names={
@@ -263,7 +270,9 @@ def __init__(self, methodName="runTest") -> None:
263270
sharding=ShardingType.ROW_WISE,
264271
),
265272
],
266-
model_tracker_config=ModelTrackerConfig(),
273+
model_tracker_config=ModelTrackerConfig(
274+
tracker_type=TrackerType.TREC
275+
),
267276
),
268277
FqnToFeatureNamesOutputTestParams(
269278
expected_fqn_to_feature_names={
@@ -296,7 +305,9 @@ def __init__(self, methodName="runTest") -> None:
296305
sharding=ShardingType.ROW_WISE,
297306
),
298307
],
299-
model_tracker_config=ModelTrackerConfig(),
308+
model_tracker_config=ModelTrackerConfig(
309+
tracker_type=TrackerType.TREC,
310+
),
300311
),
301312
FqnToFeatureNamesOutputTestParams(
302313
expected_fqn_to_feature_names={
@@ -332,7 +343,8 @@ def __init__(self, methodName="runTest") -> None:
332343
),
333344
],
334345
model_tracker_config=ModelTrackerConfig(
335-
fqns_to_skip=["sparse_table_1"]
346+
fqns_to_skip=["sparse_table_1"],
347+
tracker_type=TrackerType.TREC,
336348
),
337349
),
338350
FqnToFeatureNamesOutputTestParams(
@@ -368,7 +380,8 @@ def __init__(self, methodName="runTest") -> None:
368380
),
369381
],
370382
model_tracker_config=ModelTrackerConfig(
371-
fqns_to_skip=["embedding_bags"]
383+
fqns_to_skip=["embedding_bags"],
384+
tracker_type=TrackerType.TREC,
372385
),
373386
),
374387
FqnToFeatureNamesOutputTestParams(
@@ -399,7 +412,10 @@ def __init__(self, methodName="runTest") -> None:
399412
sharding=ShardingType.ROW_WISE,
400413
),
401414
],
402-
model_tracker_config=ModelTrackerConfig(fqns_to_skip=["ec"]),
415+
model_tracker_config=ModelTrackerConfig(
416+
fqns_to_skip=["ec"],
417+
tracker_type=TrackerType.TREC,
418+
),
403419
),
404420
FqnToFeatureNamesOutputTestParams(
405421
expected_fqn_to_feature_names={},
@@ -440,7 +456,9 @@ def test_fqn_to_feature_names(
440456
sharding=ShardingType.ROW_WISE,
441457
),
442458
],
443-
model_tracker_config=ModelTrackerConfig(),
459+
model_tracker_config=ModelTrackerConfig(
460+
tracker_type=TrackerType.TREC,
461+
),
444462
),
445463
TrackerNotInitOutputTestParams(
446464
dmp_tracker_atter="get_model_tracker",
@@ -461,7 +479,9 @@ def test_fqn_to_feature_names(
461479
sharding=ShardingType.ROW_WISE,
462480
),
463481
],
464-
model_tracker_config=ModelTrackerConfig(),
482+
model_tracker_config=ModelTrackerConfig(
483+
tracker_type=TrackerType.TREC,
484+
),
465485
),
466486
TrackerNotInitOutputTestParams(
467487
dmp_tracker_atter="get_unique",
@@ -514,6 +534,7 @@ def test_tracker_not_initialized(
514534
model_tracker_config=ModelTrackerConfig(
515535
tracking_mode=TrackingMode.ID_ONLY,
516536
delete_on_read=True,
537+
tracker_type=TrackerType.TREC,
517538
),
518539
model_inputs=[
519540
ModelInput(
@@ -563,6 +584,7 @@ def test_tracker_not_initialized(
563584
model_tracker_config=ModelTrackerConfig(
564585
tracking_mode=TrackingMode.ID_ONLY,
565586
delete_on_read=True,
587+
tracker_type=TrackerType.TREC,
566588
),
567589
model_inputs=[
568590
ModelInput(
@@ -610,6 +632,7 @@ def test_tracker_not_initialized(
610632
model_tracker_config=ModelTrackerConfig(
611633
tracking_mode=TrackingMode.ID_ONLY,
612634
delete_on_read=True,
635+
tracker_type=TrackerType.TREC,
613636
),
614637
model_inputs=[
615638
ModelInput(
@@ -653,6 +676,7 @@ def test_tracker_not_initialized(
653676
model_tracker_config=ModelTrackerConfig(
654677
tracking_mode=TrackingMode.ID_ONLY,
655678
delete_on_read=True,
679+
tracker_type=TrackerType.TREC,
656680
),
657681
model_inputs=[
658682
ModelInput(
@@ -715,6 +739,7 @@ def test_tracker_id_mode(
715739
model_tracker_config=ModelTrackerConfig(
716740
tracking_mode=TrackingMode.EMBEDDING,
717741
delete_on_read=True,
742+
tracker_type=TrackerType.TREC,
718743
),
719744
model_inputs=[
720745
ModelInput(
@@ -763,6 +788,7 @@ def test_tracker_id_mode(
763788
model_tracker_config=ModelTrackerConfig(
764789
tracking_mode=TrackingMode.EMBEDDING,
765790
delete_on_read=True,
791+
tracker_type=TrackerType.TREC,
766792
),
767793
model_inputs=[
768794
ModelInput(
@@ -808,6 +834,7 @@ def test_tracker_id_mode(
808834
model_tracker_config=ModelTrackerConfig(
809835
tracking_mode=TrackingMode.EMBEDDING,
810836
delete_on_read=True,
837+
tracker_type=TrackerType.TREC,
811838
),
812839
model_inputs=[
813840
ModelInput(
@@ -881,6 +908,7 @@ def test_tracker_embedding_mode(
881908
model_tracker_config=ModelTrackerConfig(
882909
tracking_mode=TrackingMode.EMBEDDING,
883910
delete_on_read=True,
911+
tracker_type=TrackerType.TREC,
884912
),
885913
model_inputs=[
886914
# First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -975,6 +1003,7 @@ def test_tracker_embedding_mode(
9751003
model_tracker_config=ModelTrackerConfig(
9761004
tracking_mode=TrackingMode.ID_ONLY,
9771005
delete_on_read=True,
1006+
tracker_type=TrackerType.TREC,
9781007
),
9791008
model_inputs=[
9801009
# First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -1088,6 +1117,7 @@ def test_multiple_get(
10881117
tracking_mode=TrackingMode.ID_ONLY,
10891118
delete_on_read=True,
10901119
consumers=["A", "B"],
1120+
tracker_type=TrackerType.TREC,
10911121
),
10921122
model_inputs=[
10931123
# First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -1185,6 +1215,7 @@ def test_multiple_get(
11851215
tracking_mode=TrackingMode.ID_ONLY,
11861216
delete_on_read=False,
11871217
consumers=["A", "B"],
1218+
tracker_type=TrackerType.TREC,
11881219
),
11891220
model_inputs=[
11901221
# First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -1282,6 +1313,7 @@ def test_multiple_get(
12821313
tracking_mode=TrackingMode.ID_ONLY,
12831314
delete_on_read=True,
12841315
consumers=["A", "B"],
1316+
tracker_type=TrackerType.TREC,
12851317
),
12861318
model_inputs=[
12871319
# First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -1379,6 +1411,7 @@ def test_multiple_get(
13791411
tracking_mode=TrackingMode.ID_ONLY,
13801412
delete_on_read=False,
13811413
consumers=["A", "B"],
1414+
tracker_type=TrackerType.TREC,
13821415
),
13831416
model_inputs=[
13841417
# First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -1486,6 +1519,7 @@ def test_multiple_consumers(
14861519
model_tracker_config=ModelTrackerConfig(
14871520
tracking_mode=TrackingMode.MOMENTUM_LAST,
14881521
delete_on_read=True,
1522+
tracker_type=TrackerType.TREC,
14891523
),
14901524
model_inputs=[
14911525
ModelInput(
@@ -1525,6 +1559,7 @@ def test_multiple_consumers(
15251559
model_tracker_config=ModelTrackerConfig(
15261560
tracking_mode=TrackingMode.MOMENTUM_LAST,
15271561
delete_on_read=True,
1562+
tracker_type=TrackerType.TREC,
15281563
),
15291564
model_inputs=[
15301565
ModelInput(
@@ -1581,6 +1616,7 @@ def test_duplication_with_momentum(
15811616
model_tracker_config=ModelTrackerConfig(
15821617
tracking_mode=TrackingMode.MOMENTUM_DIFF,
15831618
delete_on_read=True,
1619+
tracker_type=TrackerType.TREC,
15841620
),
15851621
model_inputs=[
15861622
ModelInput(
@@ -1620,6 +1656,7 @@ def test_duplication_with_momentum(
16201656
model_tracker_config=ModelTrackerConfig(
16211657
tracking_mode=TrackingMode.ROWWISE_ADAGRAD,
16221658
delete_on_read=True,
1659+
tracker_type=TrackerType.TREC,
16231660
),
16241661
model_inputs=[
16251662
ModelInput(
@@ -1679,6 +1716,7 @@ def _test_fqn_to_feature_names(
16791716
)
16801717

16811718
dt = dt_model.get_model_tracker()
1719+
assert isinstance(dt, ModelDeltaTrackerTrec)
16821720
unittest.TestCase().assertEqual(
16831721
dt.fqn_to_feature_names(), output_params.expected_fqn_to_feature_names
16841722
)
@@ -1732,6 +1770,7 @@ def _test_id_mode(
17321770
)
17331771
features_list = model_input_generator(test_params.model_inputs, rank)
17341772
dt = dt_model.get_model_tracker()
1773+
assert isinstance(dt, ModelDeltaTrackerTrec)
17351774
for features in features_list:
17361775
tracked_out = dt_model(features)
17371776
baseline_out = baseline_model(features)
@@ -1826,6 +1865,7 @@ def _test_embedding_mode(
18261865
# Only proceed with the rest of the test if models were created successfully
18271866
features_list = model_input_generator(test_params.model_inputs, rank)
18281867
dt = dt_model.get_model_tracker()
1868+
assert isinstance(dt, ModelDeltaTrackerTrec)
18291869

18301870
orig_emb1 = (
18311871
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`.
@@ -1939,6 +1979,7 @@ def _test_multiple_get(
19391979
)
19401980
features_list = model_input_generator(test_params.model_inputs, rank)
19411981
dt = dt_model.get_model_tracker()
1982+
assert isinstance(dt, ModelDeltaTrackerTrec)
19421983
table_fqns = dt.fqn_to_feature_names().keys()
19431984
table_fqns_list = list(table_fqns)
19441985
expected_emb1 = torch.tensor([])
@@ -2019,6 +2060,7 @@ def _test_multiple_consumer(
20192060
)
20202061
features_list = model_input_generator(test_params.model_inputs, rank)
20212062
dt = dt_model.get_model_tracker()
2063+
assert isinstance(dt, ModelDeltaTrackerTrec)
20222064
table_fqns = dt.fqn_to_feature_names().keys()
20232065
table_fqns_list = list(table_fqns)
20242066

@@ -2082,6 +2124,7 @@ def _test_duplication_with_momentum(
20822124
baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1)
20832125
features_list = model_input_generator(test_params.model_inputs, rank)
20842126
dt = dt_model.get_model_tracker()
2127+
assert isinstance(dt, ModelDeltaTrackerTrec)
20852128
table_fqns = dt.fqn_to_feature_names().keys()
20862129
table_fqns_list = list(table_fqns)
20872130
for features in features_list:
@@ -2147,6 +2190,7 @@ def _test_duplication_with_rowwise_adagrad(
21472190
features_list = model_input_generator(test_params.model_inputs, rank)
21482191

21492192
dt = dt_model.get_model_tracker()
2193+
assert isinstance(dt, ModelDeltaTrackerTrec)
21502194
table_fqns = dt.fqn_to_feature_names().keys()
21512195
table_fqns_list = list(table_fqns)
21522196

0 commit comments

Comments
 (0)