2 changes: 1 addition & 1 deletion torchrec/distributed/embedding.py

@@ -1605,7 +1605,7 @@ def compute_and_output_dist(
             ):
                 embs = lookup(features)
                 if self.post_lookup_tracker_fn is not None:
-                    self.post_lookup_tracker_fn(features, embs, self)
+                    self.post_lookup_tracker_fn(features, embs, self, None)
 
             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
18 changes: 16 additions & 2 deletions torchrec/distributed/embedding_types.py

@@ -391,7 +391,15 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
         self.post_lookup_tracker_fn: Optional[
-            Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
+            Callable[
+                [
+                    KeyedJaggedTensor,
+                    torch.Tensor,
+                    Optional[nn.Module],
+                    Optional[torch.Tensor],
+                ],
+                None,
+            ]
         ] = None
         self.post_odist_tracker_fn: Optional[Callable[..., None]] = None
 
@@ -445,7 +453,13 @@ def train(self, mode: bool = True): # pyre-ignore[3]
     def register_post_lookup_tracker_fn(
         self,
         record_fn: Callable[
-            [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
+            [
+                KeyedJaggedTensor,
+                torch.Tensor,
+                Optional[nn.Module],
+                Optional[torch.Tensor],
+            ],
+            None,
         ],
     ) -> None:
         """
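With this widening, any registered post-lookup tracker must accept four arguments, the last two optional. A minimal sketch of a compatible callback, assuming only the signature shown in this diff (the function body and the commented-out registration call are illustrative):

```python
from typing import Optional

import torch
import torch.nn as nn
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def record_lookup(
    kjt: KeyedJaggedTensor,
    states: torch.Tensor,
    emb_module: Optional[nn.Module] = None,
    raw_ids: Optional[torch.Tensor] = None,
) -> None:
    # Plain embedding/embedding-bag lookups pass raw_ids=None (see the
    # embedding.py hunk above); managed-collision modules pass the
    # original, pre-remapping identities.
    if raw_ids is not None:
        print(f"saw {raw_ids.numel()} raw ids for keys {kjt.keys()}")


# sharded_module is assumed to be a ShardedEmbeddingModule instance:
# sharded_module.register_post_lookup_tracker_fn(record_lookup)
```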
2 changes: 1 addition & 1 deletion torchrec/distributed/embeddingbag.py

@@ -1671,7 +1671,7 @@ def compute_and_output_dist(
             ):
                 embs = lookup(features)
                 if self.post_lookup_tracker_fn is not None:
-                    self.post_lookup_tracker_fn(features, embs, self)
+                    self.post_lookup_tracker_fn(features, embs, self, None)
 
             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST,
24 changes: 22 additions & 2 deletions torchrec/distributed/mc_modules.py

@@ -238,7 +238,15 @@ def __init__(
         self._use_index_dedup = use_index_dedup
         self._initialize_torch_state()
         self.post_lookup_tracker_fn: Optional[
-            Callable[[KeyedJaggedTensor, torch.Tensor], None]
+            Callable[
+                [
+                    KeyedJaggedTensor,
+                    torch.Tensor,
+                    Optional[nn.Module],
+                    Optional[torch.Tensor],
+                ],
+                None,
+            ]
         ] = None
 
     def _initialize_torch_state(self) -> None:
@@ -756,6 +764,8 @@ def compute(
                 if self.post_lookup_tracker_fn is not None:
                     self.post_lookup_tracker_fn(
                         KeyedJaggedTensor.from_jt_dict(mc_input),
+                        torch.empty(0),
+                        None,
                         mcm._hash_zch_identities.index_select(
                             dim=0, index=mc_input[table].values()
                         ),
@@ -782,6 +792,8 @@ def compute(
                 if self.post_lookup_tracker_fn is not None:
                     self.post_lookup_tracker_fn(
                         KeyedJaggedTensor.from_jt_dict(mc_input),
+                        torch.empty(0),
+                        None,
                         mcm._hash_zch_identities.index_select(dim=0, index=values),
                     )
 
@@ -876,7 +888,15 @@ def unsharded_module_type(self) -> Type[ManagedCollisionCollection]:
 
     def register_post_lookup_tracker_fn(
         self,
-        record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
+        record_fn: Callable[
+            [
+                KeyedJaggedTensor,
+                torch.Tensor,
+                Optional[nn.Module],
+                Optional[torch.Tensor],
+            ],
+            None,
+        ],
     ) -> None:
         """
         Register a function to be called after lookup is done. This is used for
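Both compute call sites now pass torch.empty(0) as a placeholder for states and fill the new raw_ids slot by gathering the pre-remapping identities with index_select. A self-contained sketch of that gather; the identity table below is invented for illustration:

```python
import torch

# Stand-in for mcm._hash_zch_identities: row i holds the raw
# (pre-remapping) id assigned to remapped slot i.
identities = torch.tensor([[1001], [2002], [3003], [4004]])

# Remapped values produced by managed collision for one feature.
values = torch.tensor([3, 0, 3, 2])

# The same selection as the call sites above: one identity row per value.
raw_ids = identities.index_select(dim=0, index=values)
assert raw_ids.squeeze(1).tolist() == [4004, 1001, 4004, 3003]
```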
94 changes: 55 additions & 39 deletions torchrec/distributed/model_parallel.py

@@ -29,8 +29,18 @@
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.comm import get_local_size
-from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec
-from torchrec.distributed.model_tracker.types import ModelTrackerConfig, UniqueRows
+from torchrec.distributed.model_tracker.model_delta_tracker import (
+    ModelDeltaTracker,
+    ModelDeltaTrackerTrec,
+)
+from torchrec.distributed.model_tracker.trackers.raw_id_tracker import RawIdTracker
+from torchrec.distributed.model_tracker.types import (
+    DeltaTrackerConfig,
+    ModelTrackerConfigs,
+    RawIdTrackerConfig,
+    Trackers,
+    UniqueRows,
+)
 
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.sharding_plan import get_default_sharders
@@ -53,6 +63,7 @@
     none_throws,
     sharded_model_copy,
 )
+
 from torchrec.optim.fused import FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 
@@ -240,7 +251,7 @@ def __init__(
         init_data_parallel: bool = True,
         init_parameters: bool = True,
         data_parallel_wrapper: Optional[DataParallelWrapper] = None,
-        model_tracker_config: Optional[ModelTrackerConfig] = None,
+        model_tracker_configs: Optional[ModelTrackerConfigs] = None,
     ) -> None:
         super().__init__()
         torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
@@ -294,11 +305,18 @@ def __init__(
         if init_data_parallel:
             self.init_data_parallel()
 
-        self.model_delta_tracker: Optional[ModelDeltaTrackerTrec] = (
-            self._init_delta_tracker(model_tracker_config, self._dmp_wrapped_module)
-            if model_tracker_config is not None
-            else None
-        )
+        self.model_trackers: Dict[str, ModelDeltaTracker] = {}
+
+        if (
+            model_tracker_configs is not None
+            and model_tracker_configs.raw_id_tracker_config is not None
+        ):
+            self.model_trackers[Trackers.RAW_ID_TRACKER.name] = (
+                self._init_raw_id_tracker(
+                    model_tracker_configs.raw_id_tracker_config,
+                    self._dmp_wrapped_module,
+                )
+            )
 
     @property
     def module(self) -> nn.Module:
@@ -321,14 +339,14 @@ def module(self, value: nn.Module) -> None:
 
     # pyre-ignore [2, 3]
     def forward(self, *args, **kwargs) -> Any:
-        if self.model_delta_tracker is not None:
+        for tracker in self.model_trackers.values():
             # The step() call advances the internal batch counter so that subsequent ID tracking and delta
             # retrieval operations can be properly organized by batch boundaries.
 
             # Context: ModelDeltaTracker tracks unique embedding IDs (and optionally embeddings / states)
             # useful for calculating topk rows for checkpointing and for updating fresh embedding weights
             # between predictors and trainers in online training scenarios.
-            self.model_delta_tracker.step()
+            tracker.step()
 
         # Model forward for DMP wrapped model
         return self._dmp_wrapped_module(*args, **kwargs)
@@ -369,16 +387,25 @@ def _init_dmp(self, module: nn.Module) -> nn.Module:
         return self._shard_modules_impl(module)
 
     def _init_delta_tracker(
-        self, model_tracker_config: ModelTrackerConfig, module: nn.Module
-    ) -> ModelDeltaTrackerTrec:
+        self, delta_tracker_config: DeltaTrackerConfig, module: nn.Module
+    ) -> ModelDeltaTracker:
         # Init delta tracker if config is provided
         return ModelDeltaTrackerTrec(
             model=module,
-            consumers=model_tracker_config.consumers,
-            delete_on_read=model_tracker_config.delete_on_read,
-            auto_compact=model_tracker_config.auto_compact,
-            mode=model_tracker_config.tracking_mode,
-            fqns_to_skip=model_tracker_config.fqns_to_skip,
+            consumers=delta_tracker_config.consumers,
+            delete_on_read=delta_tracker_config.delete_on_read,
+            auto_compact=delta_tracker_config.auto_compact,
+            mode=delta_tracker_config.tracking_mode,
+            fqns_to_skip=delta_tracker_config.fqns_to_skip,
         )
+
+    def _init_raw_id_tracker(
+        self, raw_id_tracker_config: RawIdTrackerConfig, module: nn.Module
+    ) -> RawIdTracker:
+        return RawIdTracker(
+            model=module,
+            delete_on_read=raw_id_tracker_config.delete_on_read,
+            fqns_to_skip=raw_id_tracker_config.fqns_to_skip,
+        )
 
     def _init_optim(self, module: nn.Module) -> CombinedOptimizer:
@@ -458,36 +485,25 @@ def init_parameters(module: nn.Module) -> None:
         module.apply(init_parameters)
 
     def init_torchrec_delta_tracker(
-        self, model_tracker_config: ModelTrackerConfig
-    ) -> ModelDeltaTrackerTrec:
+        self, delta_tracker_config: DeltaTrackerConfig
+    ) -> ModelDeltaTracker:
         """
         Initializes the model delta tracker if it doesn't exists.
         """
-        if self.model_delta_tracker is None:
-            self.model_delta_tracker = self._init_delta_tracker(
-                model_tracker_config, self._dmp_wrapped_module
+        if Trackers.DELTA_TRACKER.name not in self.model_trackers:
+            self.model_trackers[Trackers.DELTA_TRACKER.name] = self._init_delta_tracker(
+                delta_tracker_config, self._dmp_wrapped_module
             )
 
-        return none_throws(self.model_delta_tracker)
-
-    def get_model_tracker(self) -> ModelDeltaTrackerTrec:
-        """
-        Returns the model tracker if it exists.
-        """
-
-        assert (
-            self.model_delta_tracker is not None
-        ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
-        return self.model_delta_tracker
+        return self.model_trackers[Trackers.DELTA_TRACKER.name]
 
-    def get_unique(self, consumer: Optional[str] = None) -> Dict[str, UniqueRows]:
+    def get_delta_tracker(self) -> Optional[ModelDeltaTracker]:
         """
-        Returns the delta rows for the given consumer.
+        Returns the delta tracker if it exists.
        """
-        assert (
-            self.model_delta_tracker is not None
-        ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
-        return self.model_delta_tracker.get_unique(consumer)
+        if Trackers.DELTA_TRACKER.name in self.model_trackers:
+            return self.model_trackers[Trackers.DELTA_TRACKER.name]
+        return None
 
     def sparse_grad_parameter_names(
         self, destination: Optional[List[str]] = None, prefix: str = ""
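DistributedModelParallel now keeps its trackers in a name-keyed model_trackers dict instead of a single optional delta tracker. A hedged usage sketch, assuming ModelTrackerConfigs and RawIdTrackerConfig expose the fields referenced in this diff (model construction and process-group setup are elided; model and batch are placeholders):

```python
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.model_tracker.types import (
    ModelTrackerConfigs,
    RawIdTrackerConfig,
)

# model is assumed to be an already-constructed TorchRec module.
dmp = DistributedModelParallel(
    module=model,
    model_tracker_configs=ModelTrackerConfigs(
        raw_id_tracker_config=RawIdTrackerConfig(delete_on_read=True),
    ),
)

# forward() now calls step() on every registered tracker before
# running the wrapped module.
out = dmp(batch)

# The delta tracker is created lazily and fetched by name:
# tracker = dmp.init_torchrec_delta_tracker(my_delta_tracker_config)
# assert dmp.get_delta_tracker() is tracker
```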
3 changes: 2 additions & 1 deletion torchrec/distributed/model_tracker/__init__.py

@@ -28,7 +28,8 @@
 )
 from torchrec.distributed.model_tracker.types import (
     IndexedLookup,  # noqa
-    ModelTrackerConfig,  # noqa
+    ModelTrackerConfigs,  # noqa
+    Trackers,  # noqa
     TrackingMode,  # noqa
     UniqueRows,  # noqa
     UpdateMode,  # noqa
18 changes: 17 additions & 1 deletion torchrec/distributed/model_tracker/delta_store.py

@@ -91,6 +91,7 @@ def append(
         fqn: str,
         ids: torch.Tensor,
         states: Optional[torch.Tensor],
+        raw_ids: Optional[torch.Tensor] = None,
     ) -> None:
         """
         Append a batch of ids and states to the store for a specific table.
@@ -162,10 +163,11 @@ def append(
         fqn: str,
         ids: torch.Tensor,
         states: Optional[torch.Tensor],
+        raw_ids: Optional[torch.Tensor] = None,
     ) -> None:
         table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
         table_fqn_lookup.append(
-            IndexedLookup(batch_idx=batch_idx, ids=ids, states=states)
+            IndexedLookup(batch_idx=batch_idx, ids=ids, states=states, raw_ids=raw_ids)
         )
         self.per_fqn_lookups[fqn] = table_fqn_lookup
 
@@ -224,6 +226,20 @@ def compact(self, start_idx: int, end_idx: int) -> None:
         )
         self.per_fqn_lookups = new_per_fqn_lookups
 
+    def get_indexed_lookups(
+        self, start_idx: int, end_idx: int
+    ) -> Dict[str, List[IndexedLookup]]:
+        r"""
+        Return the per-table IndexedLookups with batch_idx in [start_idx, end_idx).
+        """
+        per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}
+        for table_fqn, lookups in self.per_fqn_lookups.items():
+            indices = [h.batch_idx for h in lookups]
+            index_l = bisect_left(indices, start_idx)
+            index_r = bisect_left(indices, end_idx)
+            per_fqn_lookups[table_fqn] = lookups[index_l:index_r]
+        return per_fqn_lookups
+
     def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]:
         r"""
         Return all unique/delta ids per table from the Delta Store.
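get_indexed_lookups relies on lookups being appended with a monotonically increasing batch_idx, so bisect_left can locate the half-open window [start_idx, end_idx) in O(log n) per table. A standalone sketch of the windowing logic:

```python
from bisect import bisect_left

# Batch indices of recorded lookups, sorted because append() runs with
# an ever-increasing batch counter.
batch_indices = [0, 1, 1, 3, 5, 8]

start_idx, end_idx = 1, 5
index_l = bisect_left(batch_indices, start_idx)  # first pos with idx >= 1
index_r = bisect_left(batch_indices, end_idx)    # first pos with idx >= 5
print(batch_indices[index_l:index_r])  # [1, 1, 3]: end_idx is excluded
```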
9 changes: 9 additions & 0 deletions torchrec/distributed/model_tracker/model_delta_tracker.py

@@ -84,6 +84,7 @@ def record_lookup(
         kjt: KeyedJaggedTensor,
         states: torch.Tensor,
         emb_module: Optional[nn.Module] = None,
+        raw_ids: Optional[torch.Tensor] = None,
     ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
@@ -131,6 +132,13 @@ def clear(self, consumer: Optional[str] = None) -> None:
         """
         pass
 
+    @abstractmethod
+    def step(self) -> None:
+        """
+        Advance the batch index for all consumers.
+        """
+        pass
+
 
 class ModelDeltaTrackerTrec(ModelDeltaTracker):
     r"""
@@ -244,6 +252,7 @@ def record_lookup(
         kjt: KeyedJaggedTensor,
         states: torch.Tensor,
         emb_module: Optional[nn.Module] = None,
+        raw_ids: Optional[torch.Tensor] = None,
     ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
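With step() added as an abstract method alongside record_lookup and clear, every concrete tracker owns its batch counter. A hypothetical minimal subclass sketching the contract; it implements only the three methods visible in this diff, and the real ABC may declare further abstract methods that also need overriding:

```python
from typing import Dict, Optional

import torch
import torch.nn as nn
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class CountingTracker(ModelDeltaTracker):
    """Toy tracker that only counts looked-up ids per batch."""

    def __init__(self) -> None:
        self._batch_idx = 0
        self._counts: Dict[int, int] = {}

    def record_lookup(
        self,
        kjt: KeyedJaggedTensor,
        states: torch.Tensor,
        emb_module: Optional[nn.Module] = None,
        raw_ids: Optional[torch.Tensor] = None,
    ) -> None:
        # Prefer the pre-remapping identities when a managed-collision
        # module supplies them; fall back to the KJT values.
        ids = raw_ids if raw_ids is not None else kjt.values()
        self._counts[self._batch_idx] = (
            self._counts.get(self._batch_idx, 0) + ids.numel()
        )

    def clear(self, consumer: Optional[str] = None) -> None:
        self._counts.clear()

    def step(self) -> None:
        # Advance the batch boundary, mirroring DMP.forward() above.
        self._batch_idx += 1
```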