diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 2f6c6b9ed..5e9a96ca0 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -1605,7 +1605,7 @@ def compute_and_output_dist( ): embs = lookup(features) if self.post_lookup_tracker_fn is not None: - self.post_lookup_tracker_fn(features, embs, self) + self.post_lookup_tracker_fn(features, embs, self, None) with maybe_annotate_embedding_event( EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index d0a5ef920..5e92fa528 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -391,7 +391,15 @@ def __init__( self._lookups: List[nn.Module] = [] self._output_dists: List[nn.Module] = [] self.post_lookup_tracker_fn: Optional[ - Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None] + Callable[ + [ + KeyedJaggedTensor, + torch.Tensor, + Optional[nn.Module], + Optional[torch.Tensor], + ], + None, + ] ] = None self.post_odist_tracker_fn: Optional[Callable[..., None]] = None @@ -445,7 +453,13 @@ def train(self, mode: bool = True): # pyre-ignore[3] def register_post_lookup_tracker_fn( self, record_fn: Callable[ - [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None + [ + KeyedJaggedTensor, + torch.Tensor, + Optional[nn.Module], + Optional[torch.Tensor], + ], + None, ], ) -> None: """ diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index fd6117884..754b9e6fa 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1671,7 +1671,7 @@ def compute_and_output_dist( ): embs = lookup(features) if self.post_lookup_tracker_fn is not None: - self.post_lookup_tracker_fn(features, embs, self) + self.post_lookup_tracker_fn(features, embs, self, None) with maybe_annotate_embedding_event( EmbeddingEvent.OUTPUT_DIST, diff 
--git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index 2dfc8f0a1..017025796 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -238,7 +238,15 @@ def __init__( self._use_index_dedup = use_index_dedup self._initialize_torch_state() self.post_lookup_tracker_fn: Optional[ - Callable[[KeyedJaggedTensor, torch.Tensor], None] + Callable[ + [ + KeyedJaggedTensor, + torch.Tensor, + Optional[nn.Module], + Optional[torch.Tensor], + ], + None, + ] ] = None def _initialize_torch_state(self) -> None: @@ -756,6 +764,8 @@ def compute( if self.post_lookup_tracker_fn is not None: self.post_lookup_tracker_fn( KeyedJaggedTensor.from_jt_dict(mc_input), + torch.empty(0), + None, mcm._hash_zch_identities.index_select( dim=0, index=mc_input[table].values() ), @@ -782,6 +792,8 @@ def compute( if self.post_lookup_tracker_fn is not None: self.post_lookup_tracker_fn( KeyedJaggedTensor.from_jt_dict(mc_input), + torch.empty(0), + None, mcm._hash_zch_identities.index_select(dim=0, index=values), ) @@ -876,7 +888,15 @@ def unsharded_module_type(self) -> Type[ManagedCollisionCollection]: def register_post_lookup_tracker_fn( self, - record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None], + record_fn: Callable[ + [ + KeyedJaggedTensor, + torch.Tensor, + Optional[nn.Module], + Optional[torch.Tensor], + ], + None, + ], ) -> None: """ Register a function to be called after lookup is done. 
This is used for diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 45f2288e9..e17e03727 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -29,8 +29,18 @@ from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size -from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec -from torchrec.distributed.model_tracker.types import ModelTrackerConfig, UniqueRows +from torchrec.distributed.model_tracker.model_delta_tracker import ( + ModelDeltaTracker, + ModelDeltaTrackerTrec, +) +from torchrec.distributed.model_tracker.trackers.raw_id_tracker import RawIdTracker +from torchrec.distributed.model_tracker.types import ( + DeltaTrackerConfig, + ModelTrackerConfigs, + RawIdTrackerConfig, + Trackers, + UniqueRows, +) from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology from torchrec.distributed.sharding_plan import get_default_sharders @@ -53,6 +63,7 @@ none_throws, sharded_model_copy, ) + from torchrec.optim.fused import FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer @@ -240,7 +251,7 @@ def __init__( init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None, - model_tracker_config: Optional[ModelTrackerConfig] = None, + model_tracker_configs: Optional[ModelTrackerConfigs] = None, ) -> None: super().__init__() torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") @@ -294,11 +305,18 @@ def __init__( if init_data_parallel: self.init_data_parallel() - self.model_delta_tracker: Optional[ModelDeltaTrackerTrec] = ( - self._init_delta_tracker(model_tracker_config, self._dmp_wrapped_module) - if model_tracker_config is not None - else None - ) + self.model_trackers: Dict[str, ModelDeltaTracker] = {} + + if 
( + model_tracker_configs is not None + and model_tracker_configs.raw_id_tracker_config is not None + ): + self.model_trackers[Trackers.RAW_ID_TRACKER.name] = ( + self._init_raw_id_tracker( + model_tracker_configs.raw_id_tracker_config, + self._dmp_wrapped_module, + ) + ) @property def module(self) -> nn.Module: @@ -321,14 +339,14 @@ def module(self, value: nn.Module) -> None: # pyre-ignore [2, 3] def forward(self, *args, **kwargs) -> Any: - if self.model_delta_tracker is not None: + for tracker in self.model_trackers.values(): # The step() call advances the internal batch counter so that subsequent ID tracking and delta # retrieval operations can be properly organized by batch boundaries. # Context: ModelDeltaTracker tracks unique embedding IDs (and optionally embeddings / states) # useful for calculating topk rows for checkpointing and for updating fresh embedding weights # between predictors and trainers in online training scenarios. - self.model_delta_tracker.step() + tracker.step() # Model forward for DMP wrapped model return self._dmp_wrapped_module(*args, **kwargs) @@ -369,16 +387,25 @@ def _init_dmp(self, module: nn.Module) -> nn.Module: return self._shard_modules_impl(module) def _init_delta_tracker( - self, model_tracker_config: ModelTrackerConfig, module: nn.Module - ) -> ModelDeltaTrackerTrec: + self, delta_tracker_config: DeltaTrackerConfig, module: nn.Module + ) -> ModelDeltaTracker: # Init delta tracker if config is provided return ModelDeltaTrackerTrec( model=module, - consumers=model_tracker_config.consumers, - delete_on_read=model_tracker_config.delete_on_read, - auto_compact=model_tracker_config.auto_compact, - mode=model_tracker_config.tracking_mode, - fqns_to_skip=model_tracker_config.fqns_to_skip, + consumers=delta_tracker_config.consumers, + delete_on_read=delta_tracker_config.delete_on_read, + auto_compact=delta_tracker_config.auto_compact, + mode=delta_tracker_config.tracking_mode, + fqns_to_skip=delta_tracker_config.fqns_to_skip, + ) + + 
def _init_raw_id_tracker( + self, raw_id_tracker_config: RawIdTrackerConfig, module: nn.Module + ) -> RawIdTracker: + return RawIdTracker( + model=module, + delete_on_read=raw_id_tracker_config.delete_on_read, + fqns_to_skip=raw_id_tracker_config.fqns_to_skip, ) def _init_optim(self, module: nn.Module) -> CombinedOptimizer: @@ -458,36 +485,25 @@ def init_parameters(module: nn.Module) -> None: module.apply(init_parameters) def init_torchrec_delta_tracker( - self, model_tracker_config: ModelTrackerConfig - ) -> ModelDeltaTrackerTrec: + self, delta_tracker_config: DeltaTrackerConfig + ) -> ModelDeltaTracker: """ Initializes the model delta tracker if it doesn't exists. """ - if self.model_delta_tracker is None: - self.model_delta_tracker = self._init_delta_tracker( - model_tracker_config, self._dmp_wrapped_module + if Trackers.DELTA_TRACKER.name not in self.model_trackers: + self.model_trackers[Trackers.DELTA_TRACKER.name] = self._init_delta_tracker( + delta_tracker_config, self._dmp_wrapped_module ) - return none_throws(self.model_delta_tracker) - - def get_model_tracker(self) -> ModelDeltaTrackerTrec: - """ - Returns the model tracker if it exists. - """ - - assert ( - self.model_delta_tracker is not None - ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init." - return self.model_delta_tracker + return self.model_trackers[Trackers.DELTA_TRACKER.name] - def get_unique(self, consumer: Optional[str] = None) -> Dict[str, UniqueRows]: + def get_delta_tracker(self) -> Optional[ModelDeltaTracker]: """ - Returns the delta rows for the given consumer. + Returns the delta tracker if it exists. """ - assert ( - self.model_delta_tracker is not None - ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init." 
- return self.model_delta_tracker.get_unique(consumer) + if Trackers.DELTA_TRACKER.name in self.model_trackers: + return self.model_trackers[Trackers.DELTA_TRACKER.name] + return None def sparse_grad_parameter_names( self, destination: Optional[List[str]] = None, prefix: str = "" diff --git a/torchrec/distributed/model_tracker/__init__.py b/torchrec/distributed/model_tracker/__init__.py index 2895e218a..81a6a8e9a 100644 --- a/torchrec/distributed/model_tracker/__init__.py +++ b/torchrec/distributed/model_tracker/__init__.py @@ -28,7 +28,8 @@ ) from torchrec.distributed.model_tracker.types import ( IndexedLookup, # noqa - ModelTrackerConfig, # noqa + ModelTrackerConfigs, # noqa + Trackers, # noqa TrackingMode, # noqa UniqueRows, # noqa UpdateMode, # noqa diff --git a/torchrec/distributed/model_tracker/delta_store.py b/torchrec/distributed/model_tracker/delta_store.py index bd2ee1b27..cfac71b8c 100644 --- a/torchrec/distributed/model_tracker/delta_store.py +++ b/torchrec/distributed/model_tracker/delta_store.py @@ -91,6 +91,7 @@ def append( fqn: str, ids: torch.Tensor, states: Optional[torch.Tensor], + raw_ids: Optional[torch.Tensor] = None, ) -> None: """ Append a batch of ids and states to the store for a specific table. @@ -162,10 +163,11 @@ def append( fqn: str, ids: torch.Tensor, states: Optional[torch.Tensor], + raw_ids: Optional[torch.Tensor] = None, ) -> None: table_fqn_lookup = self.per_fqn_lookups.get(fqn, []) table_fqn_lookup.append( - IndexedLookup(batch_idx=batch_idx, ids=ids, states=states) + IndexedLookup(batch_idx=batch_idx, ids=ids, states=states, raw_ids=raw_ids) ) self.per_fqn_lookups[fqn] = table_fqn_lookup @@ -224,6 +226,20 @@ def compact(self, start_idx: int, end_idx: int) -> None: ) self.per_fqn_lookups = new_per_fqn_lookups + def get_indexed_lookups( + self, start_idx: int, end_idx: int + ) -> Dict[str, List[IndexedLookup]]: + r""" + Return all unique/delta ids per table from the Delta Store. 
+ """ + per_fqn_lookups: Dict[str, List[IndexedLookup]] = {} + for table_fqn, lookups in self.per_fqn_lookups.items(): + indexices = [h.batch_idx for h in lookups] + index_l = bisect_left(indexices, start_idx) + index_r = bisect_left(indexices, end_idx) + per_fqn_lookups[table_fqn] = lookups[index_l:index_r] + return per_fqn_lookups + def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]: r""" Return all unique/delta ids per table from the Delta Store. diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py index 50a9bd250..3e444ee4e 100644 --- a/torchrec/distributed/model_tracker/model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/model_delta_tracker.py @@ -84,6 +84,7 @@ def record_lookup( kjt: KeyedJaggedTensor, states: torch.Tensor, emb_module: Optional[nn.Module] = None, + raw_ids: Optional[torch.Tensor] = None, ) -> None: """ Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states. @@ -131,6 +132,13 @@ def clear(self, consumer: Optional[str] = None) -> None: """ pass + @abstractmethod + def step(self) -> None: + """ + Advance the batch index for all consumers. + """ + pass + class ModelDeltaTrackerTrec(ModelDeltaTracker): r""" @@ -244,6 +252,7 @@ def record_lookup( kjt: KeyedJaggedTensor, states: torch.Tensor, emb_module: Optional[nn.Module] = None, + raw_ids: Optional[torch.Tensor] = None, ) -> None: """ Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states. 
diff --git a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py index a46268a3b..c942b6385 100644 --- a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py @@ -24,13 +24,14 @@ from torchrec.distributed.embedding import EmbeddingCollectionSharder from torchrec.distributed.embedding_types import ModuleSharder, ShardingType from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec from torchrec.distributed.model_tracker.tests.utils import ( EmbeddingTableProps, generate_planner_constraints, TestEBCModel, TestECModel, ) -from torchrec.distributed.model_tracker.types import ModelTrackerConfig, TrackingMode +from torchrec.distributed.model_tracker.types import DeltaTrackerConfig, TrackingMode from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology from torchrec.distributed.test_utils.multi_process import ( @@ -81,7 +82,7 @@ class ModelInput: class ModelDeltaTrackerInputTestParams: # input parameters embedding_config_type: Union[Type[EmbeddingConfig], Type[EmbeddingBagConfig]] - model_tracker_config: ModelTrackerConfig + delta_tracker_config: DeltaTrackerConfig embedding_tables: List[EmbeddingTableProps] model_inputs: List[ModelInput] = field(default_factory=list) consumers: List[str] = field(default_factory=list) @@ -131,7 +132,7 @@ def get_models( embedding_config_type: Union[Type[EmbeddingConfig], Type[EmbeddingBagConfig]], tables: Iterable[EmbeddingTableProps], optimizer_type: OptimType = OptimType.ADAM, - config: Optional[ModelTrackerConfig] = None, + config: Optional[DeltaTrackerConfig] = None, ) -> Tuple[DistributedModelParallel, DistributedModelParallel]: # Create the model torch.manual_seed(0) @@ -179,8 +180,9 @@ def get_models( 
env=torchrec.distributed.ShardingEnv.from_process_group(ctx.pg), plan=plan, sharders=sharders, - model_tracker_config=config, ) + if config: + dt_dmp.init_torchrec_delta_tracker(delta_tracker_config=config) torch.manual_seed(0) baseline_module = generate_test_models(embedding_config_type, tables) @@ -228,7 +230,7 @@ def __init__(self, methodName="runTest") -> None: sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig(), + delta_tracker_config=DeltaTrackerConfig(), ), FqnToFeatureNamesOutputTestParams( expected_fqn_to_feature_names={ @@ -263,7 +265,7 @@ def __init__(self, methodName="runTest") -> None: sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig(), + delta_tracker_config=DeltaTrackerConfig(), ), FqnToFeatureNamesOutputTestParams( expected_fqn_to_feature_names={ @@ -296,7 +298,7 @@ def __init__(self, methodName="runTest") -> None: sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig(), + delta_tracker_config=DeltaTrackerConfig(), ), FqnToFeatureNamesOutputTestParams( expected_fqn_to_feature_names={ @@ -331,8 +333,8 @@ def __init__(self, methodName="runTest") -> None: sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( - fqns_to_skip=["sparse_table_1"] + delta_tracker_config=DeltaTrackerConfig( + fqns_to_skip=["sparse_table_1"], ), ), FqnToFeatureNamesOutputTestParams( @@ -367,8 +369,8 @@ def __init__(self, methodName="runTest") -> None: sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( - fqns_to_skip=["embedding_bags"] + delta_tracker_config=DeltaTrackerConfig( + fqns_to_skip=["embedding_bags"], ), ), FqnToFeatureNamesOutputTestParams( @@ -399,7 +401,9 @@ def __init__(self, methodName="runTest") -> None: sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig(fqns_to_skip=["ec"]), + delta_tracker_config=DeltaTrackerConfig( + fqns_to_skip=["ec"], + ), ), FqnToFeatureNamesOutputTestParams( 
expected_fqn_to_feature_names={}, @@ -440,31 +444,10 @@ def test_fqn_to_feature_names( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig(), - ), - TrackerNotInitOutputTestParams( - dmp_tracker_atter="get_model_tracker", - ), - ), - ( - "get_unique", - ModelDeltaTrackerInputTestParams( - embedding_config_type=EmbeddingConfig, - embedding_tables=[ - EmbeddingTableProps( - embedding_table_config=EmbeddingConfig( - name="table_fqn_1", - num_embeddings=NUM_EMBEDDINGS, - embedding_dim=EMBEDDING_DIM, - feature_names=["f1", "f2", "f3"], - ), - sharding=ShardingType.ROW_WISE, - ), - ], - model_tracker_config=ModelTrackerConfig(), + delta_tracker_config=DeltaTrackerConfig(), ), TrackerNotInitOutputTestParams( - dmp_tracker_atter="get_unique", + dmp_tracker_atter="get_delta_tracker", ), ), ] @@ -511,7 +494,7 @@ def test_tracker_not_initialized( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.ID_ONLY, delete_on_read=True, ), @@ -560,7 +543,7 @@ def test_tracker_not_initialized( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.ID_ONLY, delete_on_read=True, ), @@ -607,7 +590,7 @@ def test_tracker_not_initialized( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.ID_ONLY, delete_on_read=True, ), @@ -650,7 +633,7 @@ def test_tracker_not_initialized( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.ID_ONLY, delete_on_read=True, ), @@ -712,7 +695,7 @@ def test_tracker_id_mode( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.EMBEDDING, 
delete_on_read=True, ), @@ -760,7 +743,7 @@ def test_tracker_id_mode( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.EMBEDDING, delete_on_read=True, ), @@ -805,7 +788,7 @@ def test_tracker_id_mode( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.EMBEDDING, delete_on_read=True, ), @@ -878,7 +861,7 @@ def test_tracker_embedding_mode( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.EMBEDDING, delete_on_read=True, ), @@ -972,7 +955,7 @@ def test_tracker_embedding_mode( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.ID_ONLY, delete_on_read=True, ), @@ -1084,7 +1067,7 @@ def test_multiple_get( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.ID_ONLY, delete_on_read=True, consumers=["A", "B"], @@ -1181,7 +1164,7 @@ def test_multiple_get( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.ID_ONLY, delete_on_read=False, consumers=["A", "B"], @@ -1278,7 +1261,7 @@ def test_multiple_get( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.ID_ONLY, delete_on_read=True, consumers=["A", "B"], @@ -1375,7 +1358,7 @@ def test_multiple_get( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.ID_ONLY, delete_on_read=False, consumers=["A", "B"], @@ -1483,7 +1466,7 @@ def 
test_multiple_consumers( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.MOMENTUM_LAST, delete_on_read=True, ), @@ -1522,7 +1505,7 @@ def test_multiple_consumers( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.MOMENTUM_LAST, delete_on_read=True, ), @@ -1578,7 +1561,7 @@ def test_duplication_with_momentum( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.MOMENTUM_DIFF, delete_on_read=True, ), @@ -1617,7 +1600,7 @@ def test_duplication_with_momentum( sharding=ShardingType.ROW_WISE, ), ], - model_tracker_config=ModelTrackerConfig( + delta_tracker_config=DeltaTrackerConfig( tracking_mode=TrackingMode.ROWWISE_ADAGRAD, delete_on_read=True, ), @@ -1675,10 +1658,11 @@ def _test_fqn_to_feature_names( ctx=ctx, embedding_config_type=input_params.embedding_config_type, tables=input_params.embedding_tables, - config=input_params.model_tracker_config, + config=input_params.delta_tracker_config, ) - dt = dt_model.get_model_tracker() + dt = dt_model.get_delta_tracker() + assert isinstance(dt, ModelDeltaTrackerTrec) unittest.TestCase().assertEqual( dt.fqn_to_feature_names(), output_params.expected_fqn_to_feature_names ) @@ -1704,11 +1688,9 @@ def _test_tracker_init( tables=input_params.embedding_tables, config=None, ) - with unittest.TestCase().assertRaisesRegex( - AssertionError, - "Model tracker is not initialized. 
Add ModelTrackerConfig at DistributedModelParallel init.", - ): - getattr(dt_model, output_params.dmp_tracker_atter)() + unittest.TestCase().assertEqual( + getattr(dt_model, output_params.dmp_tracker_atter)(), None + ) def _test_id_mode( @@ -1728,10 +1710,11 @@ def _test_id_mode( ctx=ctx, embedding_config_type=test_params.embedding_config_type, tables=test_params.embedding_tables, - config=test_params.model_tracker_config, + config=test_params.delta_tracker_config, ) features_list = model_input_generator(test_params.model_inputs, rank) - dt = dt_model.get_model_tracker() + dt = dt_model.get_delta_tracker() + assert isinstance(dt, ModelDeltaTrackerTrec) for features in features_list: tracked_out = dt_model(features) baseline_out = baseline_model(features) @@ -1811,7 +1794,7 @@ def _test_embedding_mode( ctx=ctx, embedding_config_type=test_params.embedding_config_type, tables=test_params.embedding_tables, - config=test_params.model_tracker_config, + config=test_params.delta_tracker_config, ) else: dt_model, baseline_model = get_models( @@ -1820,12 +1803,13 @@ def _test_embedding_mode( ctx=ctx, embedding_config_type=test_params.embedding_config_type, tables=test_params.embedding_tables, - config=test_params.model_tracker_config, + config=test_params.delta_tracker_config, ) # Only proceed with the rest of the test if models were created successfully features_list = model_input_generator(test_params.model_inputs, rank) - dt = dt_model.get_model_tracker() + dt = dt_model.get_delta_tracker() + assert isinstance(dt, ModelDeltaTrackerTrec) orig_emb1 = ( # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. 
@@ -1935,10 +1919,11 @@ def _test_multiple_get( ctx=ctx, embedding_config_type=test_params.embedding_config_type, tables=test_params.embedding_tables, - config=test_params.model_tracker_config, + config=test_params.delta_tracker_config, ) features_list = model_input_generator(test_params.model_inputs, rank) - dt = dt_model.get_model_tracker() + dt = dt_model.get_delta_tracker() + assert isinstance(dt, ModelDeltaTrackerTrec) table_fqns = dt.fqn_to_feature_names().keys() table_fqns_list = list(table_fqns) expected_emb1 = torch.tensor([]) @@ -2015,10 +2000,11 @@ def _test_multiple_consumer( ctx=ctx, embedding_config_type=test_params.embedding_config_type, tables=test_params.embedding_tables, - config=test_params.model_tracker_config, + config=test_params.delta_tracker_config, ) features_list = model_input_generator(test_params.model_inputs, rank) - dt = dt_model.get_model_tracker() + dt = dt_model.get_delta_tracker() + assert isinstance(dt, ModelDeltaTrackerTrec) table_fqns = dt.fqn_to_feature_names().keys() table_fqns_list = list(table_fqns) @@ -2076,12 +2062,13 @@ def _test_duplication_with_momentum( ctx=ctx, embedding_config_type=test_params.embedding_config_type, tables=test_params.embedding_tables, - config=test_params.model_tracker_config, + config=test_params.delta_tracker_config, ) dt_model_opt = torch.optim.Adam(dt_model.parameters(), lr=0.1) baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1) features_list = model_input_generator(test_params.model_inputs, rank) - dt = dt_model.get_model_tracker() + dt = dt_model.get_delta_tracker() + assert isinstance(dt, ModelDeltaTrackerTrec) table_fqns = dt.fqn_to_feature_names().keys() table_fqns_list = list(table_fqns) for features in features_list: @@ -2119,7 +2106,7 @@ def _test_duplication_with_rowwise_adagrad( ctx=ctx, embedding_config_type=test_params.embedding_config_type, tables=test_params.embedding_tables, - config=test_params.model_tracker_config, + config=test_params.delta_tracker_config, 
optimizer_type=OptimType.EXACT_ROWWISE_ADAGRAD, ) @@ -2146,7 +2133,8 @@ def _test_duplication_with_rowwise_adagrad( baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1) features_list = model_input_generator(test_params.model_inputs, rank) - dt = dt_model.get_model_tracker() + dt = dt_model.get_delta_tracker() + assert isinstance(dt, ModelDeltaTrackerTrec) table_fqns = dt.fqn_to_feature_names().keys() table_fqns_list = list(table_fqns) diff --git a/torchrec/distributed/model_tracker/trackers/__init__.py b/torchrec/distributed/model_tracker/trackers/__init__.py new file mode 100644 index 000000000..07a1ae891 --- /dev/null +++ b/torchrec/distributed/model_tracker/trackers/__init__.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +"""MPZCH Raw ID Tracker +""" + +from torchrec.distributed.model_tracker.trackers.raw_id_tracker import ( # noqa + RawIdTracker, +) diff --git a/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py new file mode 100644 index 000000000..42eeb90e9 --- /dev/null +++ b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +# pyre-strict + +import logging +from collections import Counter, OrderedDict +from typing import Dict, Iterable, List, Optional, Tuple + +import torch + +from torch import nn +from torchrec.distributed.embedding_types import ( + KeyedJaggedTensor, + ShardedEmbeddingTable, +) +from torchrec.distributed.mc_embeddingbag import ( + ShardedManagedCollisionEmbeddingBagCollection, +) +from torchrec.distributed.mc_modules import ShardedManagedCollisionCollection +from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec + +from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker +from torchrec.distributed.model_tracker.types import IndexedLookup, UniqueRows + +logger: logging.Logger = logging.getLogger(__name__) + +SUPPORTED_MODULES = (ShardedManagedCollisionCollection,) + + +class RawIdTracker(ModelDeltaTracker): + def __init__( + self, + model: nn.Module, + delete_on_read: bool = True, + fqns_to_skip: Iterable[str] = (), + ) -> None: + self._model = model + self._consumers: Optional[List[str]] = None + self._delete_on_read = delete_on_read + self._fqn_to_feature_map: Dict[str, List[str]] = {} + self._fqns_to_skip: Iterable[str] = fqns_to_skip + + self.curr_batch_idx: int = 0 + self.curr_compact_index: int = 0 + + # from module FQN to SUPPORTED_MODULES + self.tracked_modules: Dict[str, nn.Module] = {} + self.table_to_fqn: Dict[str, str] = {} + self.feature_to_fqn: Dict[str, str] = {} + # Generate the mapping from FQN to feature names. + self.fqn_to_feature_names() + # Validate is the mode is supported for the given module and initialize tracker functions + self._validate_and_init_tracker_fns() + # init TBE tracker wrapper and register consumer ids + self._init_tbe_tracker_wrapper(self._model) + + # per_consumer_batch_idx is used to track the batch index for each consumer. + # This is used to retrieve the delta values for a given consumer as well as + # start_ids for compaction window. 
+ + # Note: For raw id tracking, this has to be assigned after the _init_tbe_tracker_wrapper() + # call as _init_tbe_tracker_wrapper is setting up consumers for TBEs + + self.per_consumer_batch_idx: Dict[str, int] = { + c: -1 for c in (self._consumers or [self.DEFAULT_CONSUMER]) + } + + self.store: DeltaStoreTrec = DeltaStoreTrec() + + # Mapping feature name to corresponding FQNs. This is used for retrieving + # the FQN associated with a given feature name in record_lookup(). + for fqn, feature_names in self._fqn_to_feature_map.items(): + for feature_name in feature_names: + if feature_name in self.feature_to_fqn: + logger.warning( + f"Duplicate feature name: {feature_name} in fqn {fqn}" + ) + continue + self.feature_to_fqn[feature_name] = fqn + logger.info(f"feature_to_fqn: {self.feature_to_fqn}") + + def step(self) -> None: + # Move batch index forward for all consumers. + self.curr_batch_idx += 1 + + def _should_skip_fqn(self, fqn: str) -> bool: + split_fqn = fqn.split(".") + # Skipping partial FQNs present in fqns_to_skip + # TODO: Validate if we need to support more complex patterns for skipping fqns + should_skip = False + for fqn_to_skip in self._fqns_to_skip: + if fqn_to_skip in split_fqn: + logger.info(f"Skipping {fqn} because it is part of fqns_to_skip") + should_skip = True + break + return should_skip + + def _should_track_table( + self, embedding_tables: List[ShardedEmbeddingTable] + ) -> bool: + should_track = True + for table_config in embedding_tables: + for fqn_to_skip in self._fqns_to_skip: + if fqn_to_skip in table_config.name: + should_track = False + break + return should_track + + def fqn_to_feature_names(self) -> Dict[str, List[str]]: + """ + Returns a mapping of FQN to feature names from all Supported Modules [EmbeddingCollection and EmbeddingBagCollection] present in the given model. 
+ """ + if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0: + return self._fqn_to_feature_map + + table_to_feature_names: Dict[str, List[str]] = OrderedDict() + for fqn, named_module in self._model.named_modules(): + if self._should_skip_fqn(fqn): + continue + # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states. + if isinstance(named_module, SUPPORTED_MODULES): + should_track_module = True + for table_name, config in named_module._table_name_to_config.items(): + for fqn_to_skip in self._fqns_to_skip: + if fqn_to_skip in table_name: + should_track_module = False + logger.info( + f"Found {table_name} for {fqn} with features {config.feature_names} should_track_module: {should_track_module}" + ) + table_to_feature_names[table_name] = config.feature_names + if should_track_module: + self.tracked_modules[self._clean_fqn_fn(fqn)] = named_module + for table_name in table_to_feature_names: + # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn" + # will incorrectly match fqn with all the table names that have the same prefix + split_fqn = fqn.split(".") + if table_name in split_fqn: + embedding_fqn = self._clean_fqn_fn(fqn) + if table_name in self.table_to_fqn: + # Sanity check for validating that we don't have more then one table mapping to same fqn. + logger.warning( + f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}" + ) + self.table_to_fqn[table_name] = embedding_fqn + logger.info(f"Table to fqn: {self.table_to_fqn}") + flatten_names = [ + name for names in table_to_feature_names.values() for name in names + ] + # TODO: Validate if there is a better way to handle duplicate feature names. + # Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case. 
+ if len(set(flatten_names)) != len(flatten_names): + counts = Counter(flatten_names) + duplicates = [item for item, count in counts.items() if count > 1] + logger.warning(f"duplicate feature names found: {duplicates}") + + fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() + for table_name in table_to_feature_names: + if table_name not in self.table_to_fqn: + # This is likely unexpected, where we can't locate the FQN associated with this table. + logger.warning( + f"Table {table_name} not found in {self.table_to_fqn}, skipping" + ) + continue + fqn_to_feature_names[self.table_to_fqn[table_name]] = ( + table_to_feature_names[table_name] + ) + self._fqn_to_feature_map = fqn_to_feature_names + return fqn_to_feature_names + + def record_lookup( + self, + kjt: KeyedJaggedTensor, + states: torch.Tensor, + emb_module: Optional[nn.Module] = None, + raw_ids: Optional[torch.Tensor] = None, + ) -> None: + per_table_ids: Dict[str, List[torch.Tensor]] = {} + per_table_raw_ids: Dict[str, List[torch.Tensor]] = {} + + # Skip storing invalid input or raw ids + if ( + raw_ids is None + or (kjt.values().numel() == 0) + or not (raw_ids.numel() % kjt.values().numel() == 0) + ): + return + + embeddings_2d = raw_ids.view(kjt.values().numel(), -1) + + offset: int = 0 + for key in kjt.keys(): + table_fqn = self.table_to_fqn[key] + ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, []) + emb_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, []) + + ids = kjt[key].values() + ids_list.append(ids) + emb_list.append(embeddings_2d[offset : offset + ids.numel()]) + offset += ids.numel() + + per_table_ids[table_fqn] = ids_list + per_table_raw_ids[table_fqn] = emb_list + + for table_fqn, ids_list in per_table_ids.items(): + self.store.append( + batch_idx=self.curr_batch_idx, + fqn=table_fqn, + ids=torch.cat(ids_list), + states=None, + raw_ids=torch.cat(per_table_raw_ids[table_fqn]), + ) + + def _clean_fqn_fn(self, fqn: str) -> str: + # strip FQN prefixes added by DMP and 
other TorchRec operations to match state dict FQN + # handles both "_dmp_wrapped_module.module." and "module." prefixes + prefixes_to_strip = ["_dmp_wrapped_module.module.", "module."] + for prefix in prefixes_to_strip: + if fqn.startswith(prefix): + return fqn[len(prefix) :] + return fqn + + def _validate_and_init_tracker_fns(self) -> None: + "To validate the mode is supported for the given module" + for module in self.tracked_modules.values(): + if isinstance(module, SUPPORTED_MODULES): + # register post lookup function + module.register_post_lookup_tracker_fn(self.record_lookup) + + def _init_tbe_tracker_wrapper(self, module: nn.Module) -> None: + for fqn, named_module in self._model.named_modules(): + if self._should_skip_fqn(fqn): + continue + if isinstance(named_module, ShardedManagedCollisionEmbeddingBagCollection): + for lookup in named_module._embedding_module._lookups: + # pyre-ignore + for emb in lookup._emb_modules: + # Only initialize tracker for TBEs that contain tables we want to track + should_track_table = self._should_track_table( + emb._config.embedding_tables + ) + if should_track_table: + emb.init_raw_id_tracker( + self.get_indexed_lookups, + self.delete, + ) + if self._consumers is None: + self._consumers = [] + self._consumers.append(emb._emb_module.uuid) + + def get_unique_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]: + return {} + + def get_unique( + self, + consumer: Optional[str] = None, + top_percentage: Optional[float] = 1.0, + per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None, + sorted_by_indices: Optional[bool] = True, + ) -> Dict[str, UniqueRows]: + return {} + + def clear(self, consumer: Optional[str] = None) -> None: + pass + + def get_indexed_lookups( + self, + tables: List[str], + consumer: Optional[str] = None, + ) -> Dict[str, List[torch.Tensor]]: + raw_id_per_table: Dict[str, List[torch.Tensor]] = {} + consumer = consumer or self.DEFAULT_CONSUMER + assert ( + consumer in 
self.per_consumer_batch_idx
+        ), f"consumer {consumer} not present in {self.per_consumer_batch_idx.values()}"
+
+        index_end: int = self.curr_batch_idx + 1
+        index_start = self.per_consumer_batch_idx[consumer]
+        indexed_lookups = {}
+        if index_start < index_end:
+            self.per_consumer_batch_idx[consumer] = index_end
+            indexed_lookups = self.store.get_indexed_lookups(index_start, index_end)
+
+        for table in tables:
+            raw_ids_list = []
+            fqn = self.table_to_fqn[table]
+            if fqn in indexed_lookups:
+                for indexed_lookup in indexed_lookups[fqn]:
+                    if indexed_lookup.raw_ids is not None:
+                        raw_ids_list.append(indexed_lookup.raw_ids)
+            raw_id_per_table[table] = raw_ids_list
+
+        if self._delete_on_read:
+            self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
+
+        return raw_id_per_table
+
+    def delete(self, up_to_idx: Optional[int]) -> None:
+        self.store.delete(up_to_idx)
diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py
index 3fbb70063..f279f40e1 100644
--- a/torchrec/distributed/model_tracker/types.py
+++ b/torchrec/distributed/model_tracker/types.py
@@ -23,6 +23,7 @@ class IndexedLookup:
     batch_idx: int
     ids: torch.Tensor
     states: Optional[torch.Tensor]
+    raw_ids: Optional[torch.Tensor] = None
     compact: bool = False
 
 
@@ -74,8 +75,36 @@ class UpdateMode(Enum):
     LAST = "last"
+
+
+class Trackers(Enum):
+    r"""
+    Supported Trackers in TorchRec
+
+    Enums:
+        DeltaTracker: Generic Tracker for EC and EBC which tracks ids/states configured through modes
+        RawIdTracker: Specialized tracker for MPZCH for tracking raw ids
+    """
+
+    DELTA_TRACKER = "delta_tracker"
+    RAW_ID_TRACKER = "raw_id_tracker"
+
+
 @dataclass
-class ModelTrackerConfig:
+class RawIdTrackerConfig:
+    r"""
+    Configuration for ``RawIdTracker``.
+
+    Args:
+        delete_on_read (bool): whether to delete the compacted data after get_delta method is called.
+        fqns_to_skip (List[str]): list of FQNs to skip tracking.
+
+    """
+
+    delete_on_read: bool = True
+    fqns_to_skip: List[str] = field(default_factory=list)
+
+
+@dataclass
+class DeltaTrackerConfig:
     r"""
     Configuration for ``ModelDeltaTracker``.
 
     Args:
@@ -83,6 +112,8 @@ class ModelTrackerConfig:
         tracking_mode (TrackingMode): tracking mode for the delta tracker.
         consumers (Optional[List[str]]): list of consumers for the delta tracker.
         delete_on_read (bool): whether to delete the compacted data after get_delta method is called.
+        fqns_to_skip (List[str]): list of FQNs to skip tracking.
+
     """
 
@@ -91,3 +122,16 @@ class ModelTrackerConfig:
     delete_on_read: bool = True
     auto_compact: bool = False
     fqns_to_skip: List[str] = field(default_factory=list)
+
+
+@dataclass
+class ModelTrackerConfigs:
+    r"""
+    Configuration for ``ModelTracker`` implementations.
+
+    Args:
+        raw_id_tracker_config (Optional[RawIdTrackerConfig]): configuration for the ``RawIdTracker``.
+
+    """
+
+    raw_id_tracker_config: Optional[RawIdTrackerConfig] = None