@@ -866,23 +866,46 @@ class VariableBatchPooledEmbeddingsAllToAll(nn.Module):
    Example::

        kjt_split = [1, 2]
+
+        # kjt_split gives the number of features owned by each rank: here rank 0 (t0) owns f_0
+        # and rank 1 (t1) owns f_1 and f_2.
+
        emb_dim_per_rank_per_feature = [[2], [3, 3]]
        a2a = VariableBatchPooledEmbeddingsAllToAll(
            pg, emb_dim_per_rank_per_feature, device
        )

        t0 = torch.rand(6) # 2 * (2 + 1)
-        t1 = torch.rand(24) # 3 * (1 + 3) + 3 * (2 + 2)
+        t1 = torch.rand(24) # 3 * (1 + 2) + 3 * (3 + 2)
+
+        # t0 and t1 are the flattened send buffers of pooled embedding outputs, produced on the
+        # ranks that own the features; each length is embedding_dim * (sum of variable batch sizes
+        # for that feature across all ranks), summed over the features owned by the sending rank.
+
        #        r0_batch_size   r1_batch_size
        #  f_0:       2               1
        -----------------------------------------
        #  f_1:       1               2
        #  f_2:       3               2
+
+        # batch_size_per_rank_per_feature is specified from the perspective of the sending rank:
+        # outer index = destination rank, inner list = features owned by the sending rank (in emb_dim_per_rank_per_feature order).
+
        r0_batch_size_per_rank_per_feature = [[2], [1]]
        r1_batch_size_per_rank_per_feature = [[1, 3], [2, 2]]
+
+        #        r0 wants   r1 wants
+        #  f_0:      2          1
+        #  f_1:      1          2
+        #  f_2:      3          2
+        # which informs the batch_size_per_feature_pre_a2a vectors below
+
        r0_batch_size_per_feature_pre_a2a = [2, 1, 3]
        r1_batch_size_per_feature_pre_a2a = [1, 2, 2]

+        # r0 should receive f_0: 2 (from r0), f_1: 1 (from r1), f_2: 3 (from r1)
+        # r1 should receive f_0: 1 (from r0), f_1: 2 (from r1), f_2: 2 (from r1)
+
        rank0_output = a2a(
            t0, r0_batch_size_per_rank_per_feature, r0_batch_size_per_feature_pre_a2a
        ).wait()
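
For reference, the arithmetic the new comments describe can be checked with a short standalone sketch (not part of the patch). The table `batch_size_per_feature_per_rank` and the helper functions below are illustrative names only, not torchrec APIs; they simply re-derive the buffer lengths and batch-size vectors used in the docstring example.

```python
# Standalone sketch: re-derive the numbers from the docstring example.
# `batch_size_per_feature_per_rank` and the helpers below are illustrative only,
# not torchrec APIs.

kjt_split = [1, 2]                            # rank 0 owns 1 feature, rank 1 owns 2
emb_dim_per_rank_per_feature = [[2], [3, 3]]  # f_0 has dim 2; f_1 and f_2 have dim 3

# Variable batch sizes indexed as [feature][requesting_rank]:
#          r0  r1
#   f_0:    2   1
#   f_1:    1   2
#   f_2:    3   2
batch_size_per_feature_per_rank = [[2, 1], [1, 2], [3, 2]]


def send_buffer_lengths():
    """Flattened send-buffer length per owning rank:
    sum over its features of emb_dim * (total batch size across all ranks)."""
    lengths = []
    feature_offset = 0
    for dims in emb_dim_per_rank_per_feature:
        total = 0
        for local_f, dim in enumerate(dims):
            total += dim * sum(batch_size_per_feature_per_rank[feature_offset + local_f])
        lengths.append(total)
        feature_offset += len(dims)
    return lengths


def batch_size_per_rank_per_feature(src_rank):
    """Sending-rank view: outer index = destination rank,
    inner list = batch sizes of the features owned by src_rank."""
    start = sum(kjt_split[:src_rank])
    owned = range(start, start + kjt_split[src_rank])
    return [
        [batch_size_per_feature_per_rank[f][dst] for f in owned]
        for dst in range(len(kjt_split))
    ]


def batch_size_per_feature_pre_a2a(dst_rank):
    """Batch sizes a destination rank expects to receive, one entry per global feature."""
    return [sizes[dst_rank] for sizes in batch_size_per_feature_per_rank]


assert send_buffer_lengths() == [6, 24]       # t0 = torch.rand(6), t1 = torch.rand(24)
assert batch_size_per_rank_per_feature(0) == [[2], [1]]
assert batch_size_per_rank_per_feature(1) == [[1, 3], [2, 2]]
assert batch_size_per_feature_pre_a2a(0) == [2, 1, 3]
assert batch_size_per_feature_pre_a2a(1) == [1, 2, 2]
```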