2424from torchrec .distributed .embedding import EmbeddingCollectionSharder
2525from torchrec .distributed .embedding_types import ModuleSharder , ShardingType
2626from torchrec .distributed .embeddingbag import EmbeddingBagCollectionSharder
27+ from torchrec .distributed .model_tracker .model_delta_tracker import ModelDeltaTrackerTrec
2728from torchrec .distributed .model_tracker .tests .utils import (
2829 EmbeddingTableProps ,
2930 generate_planner_constraints ,
3031 TestEBCModel ,
3132 TestECModel ,
3233)
33- from torchrec .distributed .model_tracker .types import ModelTrackerConfig , TrackingMode
34+ from torchrec .distributed .model_tracker .types import (
35+ ModelTrackerConfig ,
36+ TrackerType ,
37+ TrackingMode ,
38+ )
3439
3540from torchrec .distributed .planner import EmbeddingShardingPlanner , Topology
3641from torchrec .distributed .test_utils .multi_process import (
@@ -228,7 +233,9 @@ def __init__(self, methodName="runTest") -> None:
228233 sharding = ShardingType .ROW_WISE ,
229234 ),
230235 ],
231- model_tracker_config = ModelTrackerConfig (),
236+ model_tracker_config = ModelTrackerConfig (
237+ tracker_type = TrackerType .TREC
238+ ),
232239 ),
233240 FqnToFeatureNamesOutputTestParams (
234241 expected_fqn_to_feature_names = {
@@ -263,7 +270,9 @@ def __init__(self, methodName="runTest") -> None:
263270 sharding = ShardingType .ROW_WISE ,
264271 ),
265272 ],
266- model_tracker_config = ModelTrackerConfig (),
273+ model_tracker_config = ModelTrackerConfig (
274+ tracker_type = TrackerType .TREC
275+ ),
267276 ),
268277 FqnToFeatureNamesOutputTestParams (
269278 expected_fqn_to_feature_names = {
@@ -296,7 +305,9 @@ def __init__(self, methodName="runTest") -> None:
296305 sharding = ShardingType .ROW_WISE ,
297306 ),
298307 ],
299- model_tracker_config = ModelTrackerConfig (),
308+ model_tracker_config = ModelTrackerConfig (
309+ tracker_type = TrackerType .TREC ,
310+ ),
300311 ),
301312 FqnToFeatureNamesOutputTestParams (
302313 expected_fqn_to_feature_names = {
@@ -332,7 +343,8 @@ def __init__(self, methodName="runTest") -> None:
332343 ),
333344 ],
334345 model_tracker_config = ModelTrackerConfig (
335- fqns_to_skip = ["sparse_table_1" ]
346+ fqns_to_skip = ["sparse_table_1" ],
347+ tracker_type = TrackerType .TREC ,
336348 ),
337349 ),
338350 FqnToFeatureNamesOutputTestParams (
@@ -368,7 +380,8 @@ def __init__(self, methodName="runTest") -> None:
368380 ),
369381 ],
370382 model_tracker_config = ModelTrackerConfig (
371- fqns_to_skip = ["embedding_bags" ]
383+ fqns_to_skip = ["embedding_bags" ],
384+ tracker_type = TrackerType .TREC ,
372385 ),
373386 ),
374387 FqnToFeatureNamesOutputTestParams (
@@ -399,7 +412,10 @@ def __init__(self, methodName="runTest") -> None:
399412 sharding = ShardingType .ROW_WISE ,
400413 ),
401414 ],
402- model_tracker_config = ModelTrackerConfig (fqns_to_skip = ["ec" ]),
415+ model_tracker_config = ModelTrackerConfig (
416+ fqns_to_skip = ["ec" ],
417+ tracker_type = TrackerType .TREC ,
418+ ),
403419 ),
404420 FqnToFeatureNamesOutputTestParams (
405421 expected_fqn_to_feature_names = {},
@@ -440,7 +456,9 @@ def test_fqn_to_feature_names(
440456 sharding = ShardingType .ROW_WISE ,
441457 ),
442458 ],
443- model_tracker_config = ModelTrackerConfig (),
459+ model_tracker_config = ModelTrackerConfig (
460+ tracker_type = TrackerType .TREC ,
461+ ),
444462 ),
445463 TrackerNotInitOutputTestParams (
446464 dmp_tracker_atter = "get_model_tracker" ,
@@ -461,7 +479,9 @@ def test_fqn_to_feature_names(
461479 sharding = ShardingType .ROW_WISE ,
462480 ),
463481 ],
464- model_tracker_config = ModelTrackerConfig (),
482+ model_tracker_config = ModelTrackerConfig (
483+ tracker_type = TrackerType .TREC ,
484+ ),
465485 ),
466486 TrackerNotInitOutputTestParams (
467487 dmp_tracker_atter = "get_unique" ,
@@ -514,6 +534,7 @@ def test_tracker_not_initialized(
514534 model_tracker_config = ModelTrackerConfig (
515535 tracking_mode = TrackingMode .ID_ONLY ,
516536 delete_on_read = True ,
537+ tracker_type = TrackerType .TREC ,
517538 ),
518539 model_inputs = [
519540 ModelInput (
@@ -563,6 +584,7 @@ def test_tracker_not_initialized(
563584 model_tracker_config = ModelTrackerConfig (
564585 tracking_mode = TrackingMode .ID_ONLY ,
565586 delete_on_read = True ,
587+ tracker_type = TrackerType .TREC ,
566588 ),
567589 model_inputs = [
568590 ModelInput (
@@ -610,6 +632,7 @@ def test_tracker_not_initialized(
610632 model_tracker_config = ModelTrackerConfig (
611633 tracking_mode = TrackingMode .ID_ONLY ,
612634 delete_on_read = True ,
635+ tracker_type = TrackerType .TREC ,
613636 ),
614637 model_inputs = [
615638 ModelInput (
@@ -653,6 +676,7 @@ def test_tracker_not_initialized(
653676 model_tracker_config = ModelTrackerConfig (
654677 tracking_mode = TrackingMode .ID_ONLY ,
655678 delete_on_read = True ,
679+ tracker_type = TrackerType .TREC ,
656680 ),
657681 model_inputs = [
658682 ModelInput (
@@ -715,6 +739,7 @@ def test_tracker_id_mode(
715739 model_tracker_config = ModelTrackerConfig (
716740 tracking_mode = TrackingMode .EMBEDDING ,
717741 delete_on_read = True ,
742+ tracker_type = TrackerType .TREC ,
718743 ),
719744 model_inputs = [
720745 ModelInput (
@@ -763,6 +788,7 @@ def test_tracker_id_mode(
763788 model_tracker_config = ModelTrackerConfig (
764789 tracking_mode = TrackingMode .EMBEDDING ,
765790 delete_on_read = True ,
791+ tracker_type = TrackerType .TREC ,
766792 ),
767793 model_inputs = [
768794 ModelInput (
@@ -808,6 +834,7 @@ def test_tracker_id_mode(
808834 model_tracker_config = ModelTrackerConfig (
809835 tracking_mode = TrackingMode .EMBEDDING ,
810836 delete_on_read = True ,
837+ tracker_type = TrackerType .TREC ,
811838 ),
812839 model_inputs = [
813840 ModelInput (
@@ -881,6 +908,7 @@ def test_tracker_embedding_mode(
881908 model_tracker_config = ModelTrackerConfig (
882909 tracking_mode = TrackingMode .EMBEDDING ,
883910 delete_on_read = True ,
911+ tracker_type = TrackerType .TREC ,
884912 ),
885913 model_inputs = [
886914 # First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -975,6 +1003,7 @@ def test_tracker_embedding_mode(
9751003 model_tracker_config = ModelTrackerConfig (
9761004 tracking_mode = TrackingMode .ID_ONLY ,
9771005 delete_on_read = True ,
1006+ tracker_type = TrackerType .TREC ,
9781007 ),
9791008 model_inputs = [
9801009 # First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -1088,6 +1117,7 @@ def test_multiple_get(
10881117 tracking_mode = TrackingMode .ID_ONLY ,
10891118 delete_on_read = True ,
10901119 consumers = ["A" , "B" ],
1120+ tracker_type = TrackerType .TREC ,
10911121 ),
10921122 model_inputs = [
10931123 # First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -1185,6 +1215,7 @@ def test_multiple_get(
11851215 tracking_mode = TrackingMode .ID_ONLY ,
11861216 delete_on_read = False ,
11871217 consumers = ["A" , "B" ],
1218+ tracker_type = TrackerType .TREC ,
11881219 ),
11891220 model_inputs = [
11901221 # First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -1282,6 +1313,7 @@ def test_multiple_get(
12821313 tracking_mode = TrackingMode .ID_ONLY ,
12831314 delete_on_read = True ,
12841315 consumers = ["A" , "B" ],
1316+ tracker_type = TrackerType .TREC ,
12851317 ),
12861318 model_inputs = [
12871319 # First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -1379,6 +1411,7 @@ def test_multiple_get(
13791411 tracking_mode = TrackingMode .ID_ONLY ,
13801412 delete_on_read = False ,
13811413 consumers = ["A" , "B" ],
1414+ tracker_type = TrackerType .TREC ,
13821415 ),
13831416 model_inputs = [
13841417 # First input: f1 has values 0,2,4,6 and f2 has values 8,10,12,14
@@ -1486,6 +1519,7 @@ def test_multiple_consumers(
14861519 model_tracker_config = ModelTrackerConfig (
14871520 tracking_mode = TrackingMode .MOMENTUM_LAST ,
14881521 delete_on_read = True ,
1522+ tracker_type = TrackerType .TREC ,
14891523 ),
14901524 model_inputs = [
14911525 ModelInput (
@@ -1525,6 +1559,7 @@ def test_multiple_consumers(
15251559 model_tracker_config = ModelTrackerConfig (
15261560 tracking_mode = TrackingMode .MOMENTUM_LAST ,
15271561 delete_on_read = True ,
1562+ tracker_type = TrackerType .TREC ,
15281563 ),
15291564 model_inputs = [
15301565 ModelInput (
@@ -1581,6 +1616,7 @@ def test_duplication_with_momentum(
15811616 model_tracker_config = ModelTrackerConfig (
15821617 tracking_mode = TrackingMode .MOMENTUM_DIFF ,
15831618 delete_on_read = True ,
1619+ tracker_type = TrackerType .TREC ,
15841620 ),
15851621 model_inputs = [
15861622 ModelInput (
@@ -1620,6 +1656,7 @@ def test_duplication_with_momentum(
16201656 model_tracker_config = ModelTrackerConfig (
16211657 tracking_mode = TrackingMode .ROWWISE_ADAGRAD ,
16221658 delete_on_read = True ,
1659+ tracker_type = TrackerType .TREC ,
16231660 ),
16241661 model_inputs = [
16251662 ModelInput (
@@ -1679,6 +1716,7 @@ def _test_fqn_to_feature_names(
16791716 )
16801717
16811718 dt = dt_model .get_model_tracker ()
1719+ assert isinstance (dt , ModelDeltaTrackerTrec )
16821720 unittest .TestCase ().assertEqual (
16831721 dt .fqn_to_feature_names (), output_params .expected_fqn_to_feature_names
16841722 )
@@ -1732,6 +1770,7 @@ def _test_id_mode(
17321770 )
17331771 features_list = model_input_generator (test_params .model_inputs , rank )
17341772 dt = dt_model .get_model_tracker ()
1773+ assert isinstance (dt , ModelDeltaTrackerTrec )
17351774 for features in features_list :
17361775 tracked_out = dt_model (features )
17371776 baseline_out = baseline_model (features )
@@ -1826,6 +1865,7 @@ def _test_embedding_mode(
18261865 # Only proceed with the rest of the test if models were created successfully
18271866 features_list = model_input_generator (test_params .model_inputs , rank )
18281867 dt = dt_model .get_model_tracker ()
1868+ assert isinstance (dt , ModelDeltaTrackerTrec )
18291869
18301870 orig_emb1 = (
18311871 # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`.
@@ -1939,6 +1979,7 @@ def _test_multiple_get(
19391979 )
19401980 features_list = model_input_generator (test_params .model_inputs , rank )
19411981 dt = dt_model .get_model_tracker ()
1982+ assert isinstance (dt , ModelDeltaTrackerTrec )
19421983 table_fqns = dt .fqn_to_feature_names ().keys ()
19431984 table_fqns_list = list (table_fqns )
19441985 expected_emb1 = torch .tensor ([])
@@ -2019,6 +2060,7 @@ def _test_multiple_consumer(
20192060 )
20202061 features_list = model_input_generator (test_params .model_inputs , rank )
20212062 dt = dt_model .get_model_tracker ()
2063+ assert isinstance (dt , ModelDeltaTrackerTrec )
20222064 table_fqns = dt .fqn_to_feature_names ().keys ()
20232065 table_fqns_list = list (table_fqns )
20242066
@@ -2082,6 +2124,7 @@ def _test_duplication_with_momentum(
20822124 baseline_opt = torch .optim .Adam (baseline_model .parameters (), lr = 0.1 )
20832125 features_list = model_input_generator (test_params .model_inputs , rank )
20842126 dt = dt_model .get_model_tracker ()
2127+ assert isinstance (dt , ModelDeltaTrackerTrec )
20852128 table_fqns = dt .fqn_to_feature_names ().keys ()
20862129 table_fqns_list = list (table_fqns )
20872130 for features in features_list :
@@ -2147,6 +2190,7 @@ def _test_duplication_with_rowwise_adagrad(
21472190 features_list = model_input_generator (test_params .model_inputs , rank )
21482191
21492192 dt = dt_model .get_model_tracker ()
2193+ assert isinstance (dt , ModelDeltaTrackerTrec )
21502194 table_fqns = dt .fqn_to_feature_names ().keys ()
21512195 table_fqns_list = list (table_fqns )
21522196
0 commit comments