@@ -17,7 +17,8 @@
 from math import sqrt
 from typing import (
     Any,
+    Callable,
     cast,
     Dict,
     Generic,
@@ -70,6 +71,7 @@
     GroupedEmbeddingConfig,
     ShardedEmbeddingTable,
 )
+from torchrec.distributed.model_tracker.types import IndexedLookup
 from torchrec.distributed.shards_wrapper import LocalShardsWrapper
 from torchrec.distributed.types import (
     Shard,
@@ -80,6 +82,7 @@
     TensorProperties,
 )
 from torchrec.distributed.utils import append_prefix, none_throws
+
 from torchrec.modules.embedding_configs import (
     CountBasedEvictionPolicy,
     CountTimestampMixedEvictionPolicy,
@@ -97,13 +100,30 @@
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

+
 logger: logging.Logger = logging.getLogger(__name__)

 RES_ENABLED_TABLES_STR = "res_enabled_tables"
 RES_STORE_SHARDS_STR = "res_store_shards"
 ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming"


+class RawIdTrackerWrapper:
+    def __init__(
+        self,
+        get_indexed_lookups: Callable[
+            [List[str], Optional[str]],
+            Dict[str, List[torch.Tensor]],
+        ],
+        delete: Callable[
+            [int],
+            None,
+        ],
+    ) -> None:
+        self.get_indexed_lookups = get_indexed_lookups
+        self.delete = delete
+
+
 def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]:
     # populate res_params, which is used for raw embedding streaming
     # here only populates the params available in fused_params and TBE configs
@@ -2526,6 +2546,7 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._raw_id_tracker_wrapper: Optional[RawIdTrackerWrapper] = None

         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
@@ -2668,6 +2689,22 @@ def named_parameters_by_table(
         for name, param in self._param_per_table.items():
             yield name, param

+    def init_raw_id_tracker(
+        self,
+        get_indexed_lookups: Callable[
+            [List[str], Optional[str]],
+            Dict[str, List[torch.Tensor]],
+        ],
+        delete: Callable[
+            [int],
+            None,
+        ],
+    ) -> None:
+        if isinstance(self._emb_module, SplitTableBatchedEmbeddingBagsCodegen):
+            self._raw_id_tracker_wrapper = RawIdTrackerWrapper(
+                get_indexed_lookups, delete
+            )
+

 class KeyValueEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule):
     def __init__(
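For context on the new hooks, here is a minimal sketch of the callback contract that `init_raw_id_tracker` accepts. Everything in it is illustrative: the `ToyRawIdTracker`, the table name, and the interpretation of the `Optional[str]` argument and of the `int` passed to `delete` are assumptions, not part of this diff; only the two callback signatures mirror the code above, and `RawIdTrackerWrapper` is re-declared locally so the snippet runs without torchrec installed.

```python
from typing import Callable, Dict, List, Optional

import torch


# Local stand-in mirroring the wrapper added in this diff, so the sketch is
# self-contained.
class RawIdTrackerWrapper:
    def __init__(
        self,
        get_indexed_lookups: Callable[
            [List[str], Optional[str]], Dict[str, List[torch.Tensor]]
        ],
        delete: Callable[[int], None],
    ) -> None:
        self.get_indexed_lookups = get_indexed_lookups
        self.delete = delete


# Hypothetical in-memory tracker used only for illustration.
class ToyRawIdTracker:
    def __init__(self) -> None:
        self._per_table: Dict[str, List[torch.Tensor]] = {}

    def record(self, table: str, ids: torch.Tensor) -> None:
        self._per_table.setdefault(table, []).append(ids)

    def get_indexed_lookups(
        self, table_names: List[str], consumer: Optional[str] = None
    ) -> Dict[str, List[torch.Tensor]]:
        # Return the raw ids recorded for the requested tables.
        return {t: self._per_table.get(t, []) for t in table_names}

    def delete(self, up_to_idx: int) -> None:
        # Drop lookups that have already been consumed (kept trivial here).
        self._per_table.clear()


tracker = ToyRawIdTracker()
tracker.record("table_0", torch.tensor([7, 42, 7]))

# In the PR, a kernel backed by SplitTableBatchedEmbeddingBagsCodegen would
# receive these callbacks via init_raw_id_tracker(tracker.get_indexed_lookups,
# tracker.delete) and keep them in self._raw_id_tracker_wrapper. Here we just
# build the wrapper directly to show the contract.
wrapper = RawIdTrackerWrapper(tracker.get_indexed_lookups, tracker.delete)
print(wrapper.get_indexed_lookups(["table_0"], None))
wrapper.delete(0)
```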