File tree Expand file tree Collapse file tree 5 files changed +22
-0
lines changed Expand file tree Collapse file tree 5 files changed +22
-0
lines changed Original file line number Diff line number Diff line change @@ -565,6 +565,7 @@ def _group_tables_per_rank(
565565 ),
566566 _prefetch_and_cached (table ),
567567 table .use_virtual_table if is_inference else None ,
568+ table .enable_embedding_update ,
568569 )
569570 # micromanage the order of we traverse the groups to ensure backwards compatibility
570571 if grouping_key not in groups :
@@ -581,6 +582,7 @@ def _group_tables_per_rank(
581582 _ ,
582583 _ ,
583584 use_virtual_table ,
585+ enable_embedding_update ,
584586 ) = grouping_key
585587 grouped_tables = groups [grouping_key ]
586588 # remove non-native fused params
@@ -602,6 +604,7 @@ def _group_tables_per_rank(
602604 compute_kernel = compute_kernel_type ,
603605 embedding_tables = grouped_tables ,
604606 fused_params = per_tbe_fused_params ,
607+ enable_embedding_update = enable_embedding_update ,
605608 )
606609 )
607610 return grouped_embedding_configs
Original file line number Diff line number Diff line change @@ -251,6 +251,8 @@ class GroupedEmbeddingConfig:
251251 compute_kernel : EmbeddingComputeKernel
252252 embedding_tables : List [ShardedEmbeddingTable ]
253253 fused_params : Optional [Dict [str , Any ]] = None
254+ # Write-enabled Embedding Tables cannot be grouped with read-only Embedding Tables TBE needs to be separate.
255+ enable_embedding_update : bool = False
254256
255257 def feature_hash_sizes (self ) -> List [int ]:
256258 feature_hash_sizes = []
Original file line number Diff line number Diff line change @@ -223,6 +223,7 @@ def _shard(
223223 total_num_buckets = info .embedding_config .total_num_buckets ,
224224 use_virtual_table = info .embedding_config .use_virtual_table ,
225225 virtual_table_eviction_policy = info .embedding_config .virtual_table_eviction_policy ,
226+ enable_embedding_update = info .embedding_config .enable_embedding_update ,
226227 )
227228 )
228229 return tables_per_rank
@@ -278,6 +279,20 @@ def _get_feature_hash_sizes(self) -> List[int]:
278279 feature_hash_sizes .extend (group_config .feature_hash_sizes ())
279280 return feature_hash_sizes
280281
282+ def _get_num_writable_features (self ) -> int :
283+ return sum (
284+ group_config .num_features ()
285+ for group_config in self ._grouped_embedding_configs
286+ if group_config .enable_embedding_update
287+ )
288+
289+ def _get_writable_feature_hash_sizes (self ) -> List [int ]:
290+ feature_hash_sizes : List [int ] = []
291+ for group_config in self ._grouped_embedding_configs :
292+ if group_config .enable_embedding_update :
293+ feature_hash_sizes .extend (group_config .feature_hash_sizes ())
294+ return feature_hash_sizes
295+
281296
282297class RwSparseFeaturesDist (BaseSparseFeaturesDist [KeyedJaggedTensor ]):
283298 """
Original file line number Diff line number Diff line change @@ -370,6 +370,7 @@ class BaseEmbeddingConfig:
370370 total_num_buckets : Optional [int ] = None
371371 use_virtual_table : bool = False
372372 virtual_table_eviction_policy : Optional [VirtualTableEvictionPolicy ] = None
373+ enable_embedding_update : bool = False
373374
374375 def get_weight_init_max (self ) -> float :
375376 if self .weight_init_max is None :
Original file line number Diff line number Diff line change @@ -43,6 +43,7 @@ class StableEmbeddingBagConfig:
4343 total_num_buckets : Optional [int ] = None
4444 use_virtual_table : bool = False
4545 virtual_table_eviction_policy : Optional [VirtualTableEvictionPolicy ] = None
46+ enable_embedding_update : bool = False
4647 pooling : PoolingType = PoolingType .SUM
4748
4849
You can’t perform that action at this time.
0 commit comments