File tree Expand file tree Collapse file tree 5 files changed +0
-21
lines changed Expand file tree Collapse file tree 5 files changed +0
-21
lines changed Original file line number Diff line number Diff line change @@ -565,7 +565,6 @@ def _group_tables_per_rank(
565
565
),
566
566
_prefetch_and_cached (table ),
567
567
table .use_virtual_table if is_inference else None ,
568
- table .enable_embedding_update ,
569
568
)
570
569
# micromanage the order of we traverse the groups to ensure backwards compatibility
571
570
if grouping_key not in groups :
@@ -582,7 +581,6 @@ def _group_tables_per_rank(
582
581
_ ,
583
582
_ ,
584
583
use_virtual_table ,
585
- enable_embedding_update ,
586
584
) = grouping_key
587
585
grouped_tables = groups [grouping_key ]
588
586
# remove non-native fused params
@@ -604,7 +602,6 @@ def _group_tables_per_rank(
604
602
compute_kernel = compute_kernel_type ,
605
603
embedding_tables = grouped_tables ,
606
604
fused_params = per_tbe_fused_params ,
607
- enable_embedding_update = enable_embedding_update ,
608
605
)
609
606
)
610
607
return grouped_embedding_configs
Original file line number Diff line number Diff line change @@ -251,7 +251,6 @@ class GroupedEmbeddingConfig:
251
251
compute_kernel : EmbeddingComputeKernel
252
252
embedding_tables : List [ShardedEmbeddingTable ]
253
253
fused_params : Optional [Dict [str , Any ]] = None
254
- enable_embedding_update : bool = False
255
254
256
255
def feature_hash_sizes (self ) -> List [int ]:
257
256
feature_hash_sizes = []
Original file line number Diff line number Diff line change @@ -223,7 +223,6 @@ def _shard(
223
223
total_num_buckets = info .embedding_config .total_num_buckets ,
224
224
use_virtual_table = info .embedding_config .use_virtual_table ,
225
225
virtual_table_eviction_policy = info .embedding_config .virtual_table_eviction_policy ,
226
- enable_embedding_update = info .embedding_config .enable_embedding_update ,
227
226
)
228
227
)
229
228
return tables_per_rank
@@ -279,20 +278,6 @@ def _get_feature_hash_sizes(self) -> List[int]:
279
278
feature_hash_sizes .extend (group_config .feature_hash_sizes ())
280
279
return feature_hash_sizes
281
280
282
- def _get_num_writable_features (self ) -> int :
283
- return sum (
284
- group_config .num_features ()
285
- for group_config in self ._grouped_embedding_configs
286
- if group_config .enable_embedding_update
287
- )
288
-
289
- def _get_writable_feature_hash_sizes (self ) -> List [int ]:
290
- feature_hash_sizes : List [int ] = []
291
- for group_config in self ._grouped_embedding_configs :
292
- if group_config .enable_embedding_update :
293
- feature_hash_sizes .extend (group_config .feature_hash_sizes ())
294
- return feature_hash_sizes
295
-
296
281
297
282
class RwSparseFeaturesDist (BaseSparseFeaturesDist [KeyedJaggedTensor ]):
298
283
"""
Original file line number Diff line number Diff line change @@ -370,7 +370,6 @@ class BaseEmbeddingConfig:
370
370
total_num_buckets : Optional [int ] = None
371
371
use_virtual_table : bool = False
372
372
virtual_table_eviction_policy : Optional [VirtualTableEvictionPolicy ] = None
373
- enable_embedding_update : bool = False
374
373
375
374
def get_weight_init_max (self ) -> float :
376
375
if self .weight_init_max is None :
Original file line number Diff line number Diff line change @@ -43,7 +43,6 @@ class StableEmbeddingBagConfig:
43
43
total_num_buckets : Optional [int ] = None
44
44
use_virtual_table : bool = False
45
45
virtual_table_eviction_policy : Optional [VirtualTableEvictionPolicy ] = None
46
- enable_embedding_update : bool = False
47
46
pooling : PoolingType = PoolingType .SUM
48
47
49
48
You can’t perform that action at this time.
0 commit comments