
Commit 15c3d7a

iamzainhuda authored and facebook-github-bot committed
initialize both output dists for VBE in TW/TWRW (#3378)
Summary:
Pull Request resolved: #3378

There are rare cases when using VBE where one of the KJTs has the same batch size across all of its features. Such a KJT is not recognized as VBE at KJT init, which can cause issues in the forward pass. We now initialize both output dist comms to support this.

Differential Revision: D82478607

fbshipit-source-id: f91e7d1724ae09ff202b4b698a7fc0eedf177e43
1 parent 2b00007 commit 15c3d7a
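
The core pattern of the fix, sketched below as a minimal, self-contained example: keep separate attributes for the fixed-batch and variable-batch output dist modules, always create the fixed-batch one, and dispatch on the sharding context in the forward pass. The names used here (PooledDistStub, VariableDistStub, Ctx, OutputDistSketch) are illustrative stand-ins, not torchrec classes.

from typing import Optional


class PooledDistStub:
    # Stand-in for the fixed-batch-size all-to-all comm module.
    def __call__(self, embs: str) -> str:
        return f"pooled({embs})"


class VariableDistStub:
    # Stand-in for the variable-batch (VBE) all-to-all comm module.
    def __call__(self, embs: str) -> str:
        return f"variable({embs})"


class Ctx:
    # Stand-in for EmbeddingShardingContext.
    def __init__(self, variable_batch_per_feature: bool) -> None:
        self.variable_batch_per_feature = variable_batch_per_feature


class OutputDistSketch:
    def __init__(self) -> None:
        # Two separate fields instead of one Union-typed self._dist.
        self._dist: Optional[PooledDistStub] = None
        self._variable_dist: Optional[VariableDistStub] = None

    def create_output_dist(self, ctx: Optional[Ctx] = None) -> None:
        if ctx is not None and ctx.variable_batch_per_feature:
            self._variable_dist = VariableDistStub()
        # The fixed-batch dist is created unconditionally (no `else`), so a
        # KJT that was not recognized as VBE at init still has a usable
        # output dist in the forward pass.
        self._dist = PooledDistStub()

    def forward(self, embs: str, ctx: Optional[Ctx] = None) -> str:
        if ctx is not None and ctx.variable_batch_per_feature:
            assert self._variable_dist is not None, "variable batch dist is not initialized!"
            return self._variable_dist(embs)
        assert self._dist is not None
        return self._dist(embs)


dist = OutputDistSketch()
dist.create_output_dist(Ctx(variable_batch_per_feature=True))
print(dist.forward("embs", Ctx(True)))  # variable(embs)
print(dist.forward("embs"))             # pooled(embs) -- still initialized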

File tree

2 files changed: +37, -39 lines


torchrec/distributed/sharding/tw_sharding.py

Lines changed: 14 additions & 13 deletions
@@ -345,9 +345,8 @@ def __init__(
             else None
         )
         self._emb_dim_per_rank_per_feature = emb_dim_per_rank_per_feature
-        self._dist: Optional[
-            Union[PooledEmbeddingsAllToAll, VariableBatchPooledEmbeddingsAllToAll]
-        ] = None
+        self._dist: Optional[PooledEmbeddingsAllToAll] = None
+        self._variable_dist: Optional[VariableBatchPooledEmbeddingsAllToAll] = None

     def forward(
         self,
@@ -371,7 +370,10 @@ def forward(
         if sharding_ctx is None:
             return cast(PooledEmbeddingsAllToAll, self._dist)(local_embs)
         elif sharding_ctx.variable_batch_per_feature:
-            return cast(VariableBatchPooledEmbeddingsAllToAll, self._dist)(
+            assert (
+                self._variable_dist is not None
+            ), "variable batch dist is not initialized!"
+            return self._variable_dist(
                 local_embs,
                 batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
                 batch_size_per_feature_pre_a2a=sharding_ctx.batch_size_per_feature_pre_a2a,
@@ -386,21 +388,20 @@ def _create_output_dist_module(
         self, sharding_ctx: Optional[EmbeddingShardingContext] = None
     ) -> None:
         if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature:
-            self._dist = VariableBatchPooledEmbeddingsAllToAll(
+            self._variable_dist = VariableBatchPooledEmbeddingsAllToAll(
                 pg=self._pg,
                 emb_dim_per_rank_per_feature=self._emb_dim_per_rank_per_feature,
                 device=self._device,
                 callbacks=None,
                 codecs=self._codecs,
             )
-        else:
-            self._dist = PooledEmbeddingsAllToAll(
-                pg=self._pg,
-                dim_sum_per_rank=self._dim_sum_per_rank,
-                device=self._device,
-                callbacks=self._callbacks,
-                codecs=self._codecs,
-            )
+        self._dist = PooledEmbeddingsAllToAll(
+            pg=self._pg,
+            dim_sum_per_rank=self._dim_sum_per_rank,
+            device=self._device,
+            callbacks=self._callbacks,
+            codecs=self._codecs,
+        )


 class TwPooledEmbeddingSharding(
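
One consequence of splitting the Union-typed self._dist into two Optional fields, as the diff above does, is that the variable-batch call site no longer needs typing.cast: asserting the dedicated attribute is not None narrows its type and also fails loudly if the module was never created. A small hedged illustration follows, with hypothetical classes A and B standing in for the two comm modules (not torchrec code):

from typing import Optional, Union, cast


class A:
    def run(self) -> str:
        return "fixed batch"


class B:
    def run(self) -> str:
        return "variable batch"


class Before:
    def __init__(self) -> None:
        self._dist: Optional[Union[A, B]] = None

    def forward_variable(self) -> str:
        # Old pattern: one attribute holds either type, so the call site
        # needs a cast to tell the type checker which one it expects
        # (and it would fail with AttributeError if _dist was never set).
        return cast(B, self._dist).run()


class After:
    def __init__(self) -> None:
        self._dist: Optional[A] = None
        self._variable_dist: Optional[B] = None

    def forward_variable(self) -> str:
        # New pattern: the assert narrows Optional[B] to B, no cast needed,
        # and it raises a clear error if the variable-batch dist is missing.
        assert self._variable_dist is not None, "variable batch dist is not initialized!"
        return self._variable_dist.run()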

torchrec/distributed/sharding/twrw_sharding.py

Lines changed: 23 additions & 26 deletions
@@ -472,18 +472,14 @@ def __init__(
             if qcomm_codecs_registry
             else None
         )
-        self._intra_dist: Optional[
-            Union[
-                PooledEmbeddingsReduceScatter,
-                VariableBatchPooledEmbeddingsReduceScatter,
-            ]
-        ] = None
-        self._cross_dist: Optional[
-            Union[
-                PooledEmbeddingsAllToAll,
-                VariableBatchPooledEmbeddingsAllToAll,
-            ]
+        self._intra_dist: Optional[PooledEmbeddingsReduceScatter] = None
+        self._cross_dist: Optional[PooledEmbeddingsAllToAll] = None
+        self._variable_intra_dist: Optional[
+            VariableBatchPooledEmbeddingsReduceScatter
         ] = None
+        self._variable_cross_dist: Optional[VariableBatchPooledEmbeddingsAllToAll] = (
+            None
+        )

     def forward(
         self,
@@ -514,13 +510,15 @@ def forward(
                 sharding_ctx.batch_size_per_rank_per_feature,
             )
             rs_result = cast(
-                VariableBatchPooledEmbeddingsReduceScatter, self._intra_dist
+                VariableBatchPooledEmbeddingsReduceScatter, self._variable_intra_dist
             )(
                 local_embs,
                 batch_size_per_rank_per_feature=batch_size_per_feature_sum_by_cross_group,
                 embedding_dims=self._emb_dim_per_node_per_feature[current_node],
             ).wait()
-            return cast(VariableBatchPooledEmbeddingsAllToAll, self._cross_dist)(
+            return cast(
+                VariableBatchPooledEmbeddingsAllToAll, self._variable_cross_dist
+            )(
                 rs_result,
                 batch_size_per_rank_per_feature=batch_size_per_rank_per_feature_by_cross_group[
                     local_rank
@@ -615,28 +613,27 @@ def _create_output_dist_modules(
         self, sharding_ctx: Optional[EmbeddingShardingContext] = None
     ) -> None:
         if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature:
-            self._intra_dist = VariableBatchPooledEmbeddingsReduceScatter(
+            self._variable_intra_dist = VariableBatchPooledEmbeddingsReduceScatter(
                 pg=self._intra_pg,
                 codecs=self._intra_codecs,
             )
-            self._cross_dist = VariableBatchPooledEmbeddingsAllToAll(
+            self._variable_cross_dist = VariableBatchPooledEmbeddingsAllToAll(
                 pg=self._cross_pg,
                 emb_dim_per_rank_per_feature=self._emb_dim_per_node_per_feature,
                 device=self._device,
                 callbacks=None,  # don't pass permute callback, handle in LazyAwaitable
                 codecs=self._cross_codecs,
             )
-        else:
-            self._intra_dist = PooledEmbeddingsReduceScatter(
-                pg=self._intra_pg,
-                codecs=self._intra_codecs,
-            )
-            self._cross_dist = PooledEmbeddingsAllToAll(
-                pg=self._cross_pg,
-                dim_sum_per_rank=self._dim_sum_per_node,
-                device=self._device,
-                codecs=self._cross_codecs,
-            )
+        self._intra_dist = PooledEmbeddingsReduceScatter(
+            pg=self._intra_pg,
+            codecs=self._intra_codecs,
+        )
+        self._cross_dist = PooledEmbeddingsAllToAll(
+            pg=self._cross_pg,
+            dim_sum_per_rank=self._dim_sum_per_node,
+            device=self._device,
+            codecs=self._cross_codecs,
+        )


 class TwRwPooledEmbeddingSharding(
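
The TWRW output dist is two-stage: an intra-node reduce-scatter whose result feeds a cross-node all-to-all, and the diff above gives each stage its own fixed-batch and variable-batch field. Below is a minimal sketch of that composition under stated assumptions; ReduceScatterStub, AllToAllStub, and Ctx are illustrative stand-ins rather than torchrec classes (the real comm modules return awaitables that are .wait()-ed).

from typing import Optional


class ReduceScatterStub:
    # Stand-in for the intra-node reduce-scatter stage.
    def __call__(self, embs: str) -> str:
        return f"rs({embs})"


class AllToAllStub:
    # Stand-in for the cross-node all-to-all stage.
    def __call__(self, embs: str) -> str:
        return f"a2a({embs})"


class Ctx:
    # Stand-in for EmbeddingShardingContext.
    def __init__(self, variable_batch_per_feature: bool) -> None:
        self.variable_batch_per_feature = variable_batch_per_feature


class TwRwOutputDistSketch:
    def __init__(self) -> None:
        # Four fields instead of two Union-typed ones, mirroring the diff.
        self._intra_dist: Optional[ReduceScatterStub] = None
        self._cross_dist: Optional[AllToAllStub] = None
        self._variable_intra_dist: Optional[ReduceScatterStub] = None
        self._variable_cross_dist: Optional[AllToAllStub] = None

    def create_output_dists(self, ctx: Optional[Ctx] = None) -> None:
        if ctx is not None and ctx.variable_batch_per_feature:
            self._variable_intra_dist = ReduceScatterStub()
            self._variable_cross_dist = AllToAllStub()
        # Fixed-batch modules are always created, no longer behind `else`.
        self._intra_dist = ReduceScatterStub()
        self._cross_dist = AllToAllStub()

    def forward(self, embs: str, ctx: Optional[Ctx] = None) -> str:
        if ctx is not None and ctx.variable_batch_per_feature:
            assert self._variable_intra_dist is not None
            assert self._variable_cross_dist is not None
            return self._variable_cross_dist(self._variable_intra_dist(embs))
        assert self._intra_dist is not None and self._cross_dist is not None
        return self._cross_dist(self._intra_dist(embs))


sketch = TwRwOutputDistSketch()
sketch.create_output_dists(Ctx(variable_batch_per_feature=True))
print(sketch.forward("embs", Ctx(True)))  # a2a(rs(embs)) via variable-batch path
print(sketch.forward("embs"))             # a2a(rs(embs)) via fixed-batch path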
