Commit 7cf5f60

Pooja Agarwal authored and facebook-github-bot committed
Revert D81366596: Add configs for write dist
Differential Revision: D81366596
Original commit changeset: 2cc120644d31
Original Phabricator Diff: D81366596
fbshipit-source-id: 9ca96c29db0efa9da04335cb4e2c021e2c445600
1 parent 212efe1 commit 7cf5f60
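In short (per the diff below): this revert removes the enable_embedding_update flag from BaseEmbeddingConfig, GroupedEmbeddingConfig, and the stable-schema test, drops it from the table-grouping key and the grouped TBE config construction in embedding_sharding.py, and deletes the _get_num_writable_features and _get_writable_feature_hash_sizes helpers from the row-wise sharding code.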

File tree

5 files changed: +0, -21 lines

torchrec/distributed/embedding_sharding.py

Lines changed: 0 additions & 3 deletions
@@ -565,7 +565,6 @@ def _group_tables_per_rank(
            ),
            _prefetch_and_cached(table),
            table.use_virtual_table if is_inference else None,
-           table.enable_embedding_update,
        )
        # micromanage the order of we traverse the groups to ensure backwards compatibility
        if grouping_key not in groups:
@@ -582,7 +581,6 @@ def _group_tables_per_rank(
            _,
            _,
            use_virtual_table,
-           enable_embedding_update,
        ) = grouping_key
        grouped_tables = groups[grouping_key]
        # remove non-native fused params
@@ -604,7 +602,6 @@ def _group_tables_per_rank(
                compute_kernel=compute_kernel_type,
                embedding_tables=grouped_tables,
                fused_params=per_tbe_fused_params,
-               enable_embedding_update=enable_embedding_update,
            )
        )
    return grouped_embedding_configs

torchrec/distributed/embedding_types.py

Lines changed: 0 additions & 1 deletion
@@ -251,7 +251,6 @@ class GroupedEmbeddingConfig:
    compute_kernel: EmbeddingComputeKernel
    embedding_tables: List[ShardedEmbeddingTable]
    fused_params: Optional[Dict[str, Any]] = None
-   enable_embedding_update: bool = False

    def feature_hash_sizes(self) -> List[int]:
        feature_hash_sizes = []

torchrec/distributed/sharding/rw_sharding.py

Lines changed: 0 additions & 15 deletions
@@ -223,7 +223,6 @@ def _shard(
                    total_num_buckets=info.embedding_config.total_num_buckets,
                    use_virtual_table=info.embedding_config.use_virtual_table,
                    virtual_table_eviction_policy=info.embedding_config.virtual_table_eviction_policy,
-                   enable_embedding_update=info.embedding_config.enable_embedding_update,
                )
            )
        return tables_per_rank
@@ -279,20 +278,6 @@ def _get_feature_hash_sizes(self) -> List[int]:
            feature_hash_sizes.extend(group_config.feature_hash_sizes())
        return feature_hash_sizes

-   def _get_num_writable_features(self) -> int:
-       return sum(
-           group_config.num_features()
-           for group_config in self._grouped_embedding_configs
-           if group_config.enable_embedding_update
-       )
-
-   def _get_writable_feature_hash_sizes(self) -> List[int]:
-       feature_hash_sizes: List[int] = []
-       for group_config in self._grouped_embedding_configs:
-           if group_config.enable_embedding_update:
-               feature_hash_sizes.extend(group_config.feature_hash_sizes())
-       return feature_hash_sizes
-

class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
    """

torchrec/modules/embedding_configs.py

Lines changed: 0 additions & 1 deletion
@@ -370,7 +370,6 @@ class BaseEmbeddingConfig:
    total_num_buckets: Optional[int] = None
    use_virtual_table: bool = False
    virtual_table_eviction_policy: Optional[VirtualTableEvictionPolicy] = None
-   enable_embedding_update: bool = False

    def get_weight_init_max(self) -> float:
        if self.weight_init_max is None:
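For context, here is a minimal sketch of the config surface this hunk removes. It assumes EmbeddingBagConfig inherits the BaseEmbeddingConfig fields shown above; the field name comes from the deleted line, while the table name and sizes are purely illustrative.

# Sketch only. After this revert, enable_embedding_update no longer exists on
# BaseEmbeddingConfig, so passing it as a keyword would raise a TypeError.
from torchrec.modules.embedding_configs import EmbeddingBagConfig

table_config = EmbeddingBagConfig(
    name="product_id",            # illustrative table name
    embedding_dim=64,
    num_embeddings=1_000_000,
    feature_names=["product_id"],
    # enable_embedding_update=True,  # accepted only before this revert
)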

torchrec/schema/api_tests/test_embedding_config_schema.py

Lines changed: 0 additions & 1 deletion
@@ -43,7 +43,6 @@ class StableEmbeddingBagConfig:
    total_num_buckets: Optional[int] = None
    use_virtual_table: bool = False
    virtual_table_eviction_policy: Optional[VirtualTableEvictionPolicy] = None
-   enable_embedding_update: bool = False
    pooling: PoolingType = PoolingType.SUM
