@@ -472,18 +472,14 @@ def __init__(
472
472
if qcomm_codecs_registry
473
473
else None
474
474
)
475
- self ._intra_dist : Optional [
476
- Union [
477
- PooledEmbeddingsReduceScatter ,
478
- VariableBatchPooledEmbeddingsReduceScatter ,
479
- ]
480
- ] = None
481
- self ._cross_dist : Optional [
482
- Union [
483
- PooledEmbeddingsAllToAll ,
484
- VariableBatchPooledEmbeddingsAllToAll ,
485
- ]
475
+ self ._intra_dist : Optional [PooledEmbeddingsReduceScatter ] = None
476
+ self ._cross_dist : Optional [PooledEmbeddingsAllToAll ] = None
477
+ self ._variable_intra_dist : Optional [
478
+ VariableBatchPooledEmbeddingsReduceScatter
486
479
] = None
480
+ self ._variable_cross_dist : Optional [VariableBatchPooledEmbeddingsAllToAll ] = (
481
+ None
482
+ )
487
483
488
484
def forward (
489
485
self ,
@@ -514,13 +510,15 @@ def forward(
514
510
sharding_ctx .batch_size_per_rank_per_feature ,
515
511
)
516
512
rs_result = cast (
517
- VariableBatchPooledEmbeddingsReduceScatter , self ._intra_dist
513
+ VariableBatchPooledEmbeddingsReduceScatter , self ._variable_intra_dist
518
514
)(
519
515
local_embs ,
520
516
batch_size_per_rank_per_feature = batch_size_per_feature_sum_by_cross_group ,
521
517
embedding_dims = self ._emb_dim_per_node_per_feature [current_node ],
522
518
).wait ()
523
- return cast (VariableBatchPooledEmbeddingsAllToAll , self ._cross_dist )(
519
+ return cast (
520
+ VariableBatchPooledEmbeddingsAllToAll , self ._variable_cross_dist
521
+ )(
524
522
rs_result ,
525
523
batch_size_per_rank_per_feature = batch_size_per_rank_per_feature_by_cross_group [
526
524
local_rank
@@ -615,28 +613,27 @@ def _create_output_dist_modules(
615
613
self , sharding_ctx : Optional [EmbeddingShardingContext ] = None
616
614
) -> None :
617
615
if sharding_ctx is not None and sharding_ctx .variable_batch_per_feature :
618
- self ._intra_dist = VariableBatchPooledEmbeddingsReduceScatter (
616
+ self ._variable_intra_dist = VariableBatchPooledEmbeddingsReduceScatter (
619
617
pg = self ._intra_pg ,
620
618
codecs = self ._intra_codecs ,
621
619
)
622
- self ._cross_dist = VariableBatchPooledEmbeddingsAllToAll (
620
+ self ._variable_cross_dist = VariableBatchPooledEmbeddingsAllToAll (
623
621
pg = self ._cross_pg ,
624
622
emb_dim_per_rank_per_feature = self ._emb_dim_per_node_per_feature ,
625
623
device = self ._device ,
626
624
callbacks = None , # don't pass permute callback, handle in LazyAwaitable
627
625
codecs = self ._cross_codecs ,
628
626
)
629
- else :
630
- self ._intra_dist = PooledEmbeddingsReduceScatter (
631
- pg = self ._intra_pg ,
632
- codecs = self ._intra_codecs ,
633
- )
634
- self ._cross_dist = PooledEmbeddingsAllToAll (
635
- pg = self ._cross_pg ,
636
- dim_sum_per_rank = self ._dim_sum_per_node ,
637
- device = self ._device ,
638
- codecs = self ._cross_codecs ,
639
- )
627
+ self ._intra_dist = PooledEmbeddingsReduceScatter (
628
+ pg = self ._intra_pg ,
629
+ codecs = self ._intra_codecs ,
630
+ )
631
+ self ._cross_dist = PooledEmbeddingsAllToAll (
632
+ pg = self ._cross_pg ,
633
+ dim_sum_per_rank = self ._dim_sum_per_node ,
634
+ device = self ._device ,
635
+ codecs = self ._cross_codecs ,
636
+ )
640
637
641
638
642
639
class TwRwPooledEmbeddingSharding (
0 commit comments