Commit 073d7db

iamzainhuda authored and facebook-github-bot committed
improve VBEAlltoAll docstring example
Summary: Added information to the docstring to make the VBE A2A example clearer. The send and receive logic should be much easier to connect now.

Reviewed By: nipung90

Differential Revision: D82478431
1 parent eac316e commit 073d7db

1 file changed, +24 -1 lines changed

torchrec/distributed/dist_data.py

Lines changed: 24 additions & 1 deletion
@@ -866,23 +866,46 @@ class VariableBatchPooledEmbeddingsAllToAll(nn.Module):
     Example::

         kjt_split = [1, 2]
+
+        # the kjt_split informs the number of features owned by each rank, here t0 owns f0 and
+        # t1 owns f1 and f2.
+
         emb_dim_per_rank_per_feature = [[2], [3, 3]]
         a2a = VariableBatchPooledEmbeddingsAllToAll(
             pg, emb_dim_per_rank_per_feature, device
         )

         t0 = torch.rand(6) # 2 * (2 + 1)
-        t1 = torch.rand(24) # 3 * (1 + 3) + 3 * (2 + 2)
+        t1 = torch.rand(24) # 3 * (1 + 2) + 3 * (3 + 2)
+
+        # t0 and t1 are the flattened send buffers of pooled embedding outputs produced on the
+        # ranks that own the features, computed as embedding_dim * (sum of variable batch sizes
+        # for that feature across all source ranks), summed over the features owned by that destination rank.
+
         #        r0_batch_size   r1_batch_size
         #  f_0:              2               1
         -----------------------------------------
         #  f_1:              1               2
         #  f_2:              3               2
+
+        # batch_size_per_rank_per_feature tensor is specified from the perspective of the sending rank
+        # outer_index = destination rank, inner vector = features owned by the sending rank (in emb_dim_per_rank_per_feature order)
+
         r0_batch_size_per_rank_per_feature = [[2], [1]]
         r1_batch_size_per_rank_per_feature = [[1, 3], [2, 2]]
+
+        #        r0 wants   r1 wants
+        # f0:           2          1
+        # f1:           1          2
+        # f2:           3          2
+        # which informs the per_feature_pre_a2a vectors
+
         r0_batch_size_per_feature_pre_a2a = [2, 1, 3]
         r1_batch_size_per_feature_pre_a2a = [1, 2, 2]

+        # r0 should receive f0: 2 (from r0), f1: 1 (from r1), f2: 3 (from r1)
+        # r1 should receive f0: 1 (from r0), f1: 2 (from r1), f2: 2 (from r1)
+
         rank0_output = a2a(
             t0, r0_batch_size_per_rank_per_feature, r0_batch_size_per_feature_pre_a2a
         ).wait()
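
The numbers in the docstring example can be cross-checked without torch or a process group. The sketch below is plain Python, not TorchRec code: the helper names (`owned_features`, `send_buffer_numel`, and so on) are hypothetical and introduced only for illustration, and the only inputs are `kjt_split`, `emb_dim_per_rank_per_feature`, and the batch-size table from the docstring. It assumes features are assigned to ranks contiguously in `kjt_split` order, which is how the example reads.

```python
# Illustrative only: helper names are hypothetical, not part of TorchRec.
kjt_split = [1, 2]                             # rank 0 owns f0; rank 1 owns f1, f2
emb_dim_per_rank_per_feature = [[2], [3, 3]]   # embedding dims, per owning rank

# batch_sizes[feature][source_rank], copied from the docstring table.
batch_sizes = [
    [2, 1],  # f0
    [1, 2],  # f1
    [3, 2],  # f2
]
world_size = len(kjt_split)


def owned_features(rank):
    """Global feature indices owned by `rank` (assumes contiguous assignment)."""
    start = sum(kjt_split[:rank])
    return list(range(start, start + kjt_split[rank]))


def send_buffer_numel(rank):
    """emb_dim * (sum of batch sizes over all ranks), summed over owned features."""
    dims = emb_dim_per_rank_per_feature[rank]
    return sum(d * sum(batch_sizes[f]) for d, f in zip(dims, owned_features(rank)))


def batch_size_per_rank_per_feature(rank):
    """Outer index = destination rank, inner = features owned by `rank`."""
    return [
        [batch_sizes[f][dest] for f in owned_features(rank)]
        for dest in range(world_size)
    ]


def batch_size_per_feature_pre_a2a(rank):
    """Batch size that `rank` holds for every feature, in global feature order."""
    return [row[rank] for row in batch_sizes]


for r in range(world_size):
    print(r, send_buffer_numel(r),
          batch_size_per_rank_per_feature(r),
          batch_size_per_feature_pre_a2a(r))
```

Under these assumptions the helpers reproduce the values in the diff: buffer sizes 6 and 24, `[[2], [1]]` and `[[1, 3], [2, 2]]` for the per-rank-per-feature lists, and `[2, 1, 3]` and `[1, 2, 2]` for the pre-a2a vectors. This also confirms the corrected comment on `t1`, `3 * (1 + 2) + 3 * (3 + 2)`.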
