AllToAll implementation #5705
Conversation
Review updated until commit 27fa2d0

Relevant files:
- Enhancement
- Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Missing SendRecv enum
!test
```cpp
// For the following communication types, the sharded_id does not have to be
// outermost in allocation domain. Nonetheless, `tv` still needs to be
// contiguous and therefore .contiguous() at the beginning of this function.
// TODO(prmishra): Fix the layout for AllToAll.
```
Depending on where we relayout the input/output to be compliant with the alltoall requirements, this function and potentially ReorderShardedAxisPass will be affected. I will do this in a follow-up PR once the current PR has been reviewed and we agree on the approach for reordering the input/output of alltoall.
Greptile Summary
Implements AllToAll collective communication for expert-parallelism patterns where the sharded dimension changes between producer and consumer tensors. The implementation detects AllToAll when both producer and consumer have sharded dimensions but the logical dimension mapping differs (a dimension transpose). The runtime operation uses c10d's `alltoall_base` with even splits of the input and output tensors.
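For concreteness, the resharding pattern that gets detected as AllToAll looks like this (notation taken from the code comments in this PR):

```cpp
// Producer and consumer are both sharded on DIDx, but on different logical
// dimensions, so the resharding lowers to AllToAll instead of SendRecv:
//   input_tv  = [DIDx(d), n/d, m, ...]   // producer shards `n`
//   output_tv = [n, DIDx(d), m/d, ...]   // consumer shards `m`
// Detection: p_sharded && c_sharded && c_logical_id != p2c_map[p_logical_id].
```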
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant User as User Code
participant FD as FusionDefinition
participant Detect as getCommunicationInfo
participant Lower as lowerToAllToAll
participant Post as postAllToAll
participant Backend as c10d::Backend
User->>FD: define tensor with mesh_x sharding
User->>FD: add output with different axis sharded
FD->>Detect: analyze producer/consumer sharding
Note over Detect: Check p_sharded && c_sharded
Note over Detect: c_logical_id != p2c_map[p_logical_id]
Detect->>Detect: Identify as AllToAll
Detect->>Lower: lowerToAllToAll(input_tv, output_tv)
Lower->>Lower: Validate sender/receiver meshes are 1D and equal
Lower->>Lower: Create Communication(AllToAll, team=mesh.vector())
Note over Post: At execution time
Post->>Post: Validate input/output contiguity
Post->>Post: viewAsCompact(input_tensor)
Post->>Post: viewAsCompact(output_tensor)
Post->>Post: tensor_split by team_size()
Post->>Post: assertBuffersHaveSameSize
Post->>Backend: alltoall_base(output, input, empty, empty)
Backend-->>Post: Work handle
Post-->>User: Result tensor
```
4 files reviewed, 4 comments
csrc/multidevice/communication.cpp (outdated)
```cpp
// input_tv = [DIDx(d), n/d, m, ...]
// output_tv = [n, DIDx(d), m/d, ...]
// `n`: gathered dimension
// `m`: scattered dimension
// For alltoall correctness, we split `m` and reorder as [DIDx(d), d, n/d,
// m/d, ...] such that alltoall_base splits across the `d` dimension.
```
The comment describing the input/output tensor shapes appears to be inconsistent with the actual implementation below.

The comment states:
- `input_tv = [DIDx(d), n/d, m, ...]`
- `m`: scattered dimension (at position 2)

But the code at lines 644-651 operates on dimension 1 (`input_sizes.at(1)`), treating it as the scattered dimension. Based on the test case in `test_alltoall`, the actual runtime input tensor shape is `[n/d, m, k, ...]` (after permutation), where:
- Position 0: `n/d` (gathered dimension)
- Position 1: `m` (scattered dimension)
- Position 2+: other dimensions

The comment should be updated to reflect the actual runtime tensor layout rather than the logical sharding representation, or clarify that it describes logical sharding rather than physical memory layout.
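A possible rewording of that comment in terms of the per-device runtime layout (illustrative only; dimension names follow the original comment):

```cpp
// Per-device runtime layout (after permutation):
//   input_tensor  = [n/d, m,   ...]   // dim 0: gathered, dim 1: scattered
//   output_tensor = [n,   m/d, ...]
// The scattered dimension `m` (dim 1 of the runtime input) is split into d
// chunks of size m/d; chunk j is sent to device j. Each device concatenates
// its d received chunks along the gathered dimension, so the output holds
// n = d * (n/d) rows of width m/d.
```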
```cpp
if (c_logical_id == p2c_map.at(p_logical_id)) {
  fill_communication_info(
      CommunicationType::SendRecv, p_logical_id, c_logical_id);
} else {
  fill_communication_info(
      CommunicationType::AllToAll, nullptr, nullptr);
}
```
Setting both `p_sharded_id` and `c_sharded_id` to nullptr for AllToAll communication may cause issues in other parts of the codebase that expect valid `IterDomain` pointers.

For example, `getCommunicationLayout()` (lines 476-521) is called with `sharded_id` as a parameter, and at line 493 it checks `posInDomain(layout.allocation_domain(), sharded_id)`, which could fail if `sharded_id` is nullptr.

The current workaround is that AllToAll is explicitly handled in `getCommunicationLayout()` at line 489 to return early, but this creates tight coupling and a fragile dependency. Consider either:
- Storing the actual producer and consumer sharded IDs for AllToAll (even though they map to different logical dimensions)
- Adding explicit nullptr checks in all functions that consume `CommunicationInfo`
- Documenting this assumption clearly in the `CommunicationInfo` struct definition
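A minimal sketch of the third option (field names follow the review comment above; this is not the actual `CommunicationInfo` definition):

```cpp
class IterDomain;              // nvFuser IR node for a tensor dimension.
enum class CommunicationType;  // Opaque declaration; defined elsewhere.

// Sketch only: document the nullptr invariant on the struct itself.
struct CommunicationInfo {
  CommunicationType type;
  // For AllToAll, the producer and consumer shardings map to different
  // logical dimensions, so both IDs are intentionally nullptr. Any code
  // consuming CommunicationInfo must check for nullptr before dereferencing.
  IterDomain* p_sharded_id = nullptr;
  IterDomain* c_sharded_id = nullptr;
};
```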
csrc/multidevice/communication.cpp (outdated)
```cpp
NVF_CHECK(
    input_sizes.at(1) % d == 0,
    "Scattered dimension must be divisible by the team size");
```
The error message says "Scattered dimension must be divisible by the team size", but this check only validates one dimension.

For a complete AllToAll operation, both the scattered dimension (at position 1 in the runtime tensor) and the gathered dimension (at position 0) should be validated:
- The scattered dimension `input_sizes.at(1)` must be divisible by `d` (currently checked ✓)
- The gathered dimension should result in an output where the gathered size equals `d * input_sizes.at(0)`

Consider adding validation for the output tensor dimensions as well to catch shape mismatches early, rather than relying on assertBuffersHaveSameSize at line 668, which only provides a generic error.
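One possible shape for such a check (a sketch; `output_sizes` is a hypothetical analogue of `input_sizes` and may not be available at this exact point in the code):

```cpp
// Existing check: the scattered dimension must split evenly across the team.
NVF_CHECK(
    input_sizes.at(1) % d == 0,
    "Scattered dimension must be divisible by the team size");
// Hypothetical extra check on the gathered dimension of the output.
NVF_CHECK(
    output_sizes.at(0) == d * input_sizes.at(0),
    "Gathered output dimension must be team_size times the corresponding "
    "input dimension");
```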
4 files reviewed, 1 comment
```cpp
c10::intrusive_ptr<c10d::Work> postAllToAll(
    Communication* communication,
    DeviceIdxType my_device_index,
    c10d::Backend* backend,
    at::Tensor input_tensor,
    at::Tensor output_tensor) {
  NVF_ERROR(
      isTvContiguous(communication->in()),
      "Input tensor is not contiguous: ",
      communication->in(),
      " contiguity: ",
      communication->in()->domain()->getContiguityString());
  NVF_ERROR(
      isTvContiguous(communication->out()),
      "Output tensor is not contiguous: ",
      communication->out(),
      " contiguity: ",
      communication->out()->domain()->getContiguityString());

  auto flattened_input_tensor = viewAsCompact(input_tensor);
  auto flattened_output_tensor = viewAsCompact(output_tensor);

  // alltoall_base requires even splits of the input and output tensors.
  auto input_splits = at::tensor_split(
      flattened_input_tensor, communication->team_size(), /*dim=*/0);
  auto output_splits = at::tensor_split(
      flattened_output_tensor, communication->team_size(), /*dim=*/0);
  assertBuffersHaveSameSize(input_splits, output_splits);

  std::vector<int64_t> empty_split_sizes;
  return backend->alltoall_base(
      flattened_output_tensor,
      flattened_input_tensor,
      empty_split_sizes,
      empty_split_sizes,
      /*options=*/{});
}
```
logic: Missing team_size == 1 edge case handling. Other communication functions like postBroadcast (line 347) and postSendRecv (line 582) handle single-device teams by either returning early or doing local copies. AllToAll should handle this case to avoid unnecessary communication overhead and potential runtime errors.
Suggested change:

```diff
 c10::intrusive_ptr<c10d::Work> postAllToAll(
     Communication* communication,
     DeviceIdxType my_device_index,
     c10d::Backend* backend,
     at::Tensor input_tensor,
     at::Tensor output_tensor) {
+  if (communication->team().size() == 1) {
+    doLocalCopy(output_tensor, input_tensor);
+    return nullptr;
+  }
   NVF_ERROR(
       isTvContiguous(communication->in()),
       "Input tensor is not contiguous: ",
       communication->in(),
       " contiguity: ",
       communication->in()->domain()->getContiguityString());
   NVF_ERROR(
       isTvContiguous(communication->out()),
       "Output tensor is not contiguous: ",
       communication->out(),
       " contiguity: ",
       communication->out()->domain()->getContiguityString());
   auto flattened_input_tensor = viewAsCompact(input_tensor);
   auto flattened_output_tensor = viewAsCompact(output_tensor);
   // alltoall_base requires even splits of the input and output tensors.
   auto input_splits = at::tensor_split(
       flattened_input_tensor, communication->team_size(), /*dim=*/0);
   auto output_splits = at::tensor_split(
       flattened_output_tensor, communication->team_size(), /*dim=*/0);
   assertBuffersHaveSameSize(input_splits, output_splits);
   std::vector<int64_t> empty_split_sizes;
   return backend->alltoall_base(
       flattened_output_tensor,
       flattened_input_tensor,
       empty_split_sizes,
       empty_split_sizes,
       /*options=*/{});
 }
```
wujingyue left a comment
LGTM otherwise
```cpp
fill_communication_info(
    CommunicationType::SendRecv, p_logical_id, c_logical_id);

if (c_logical_id == p2c_map.at(p_logical_id)) {
```
As long as tests pass, I'm fine with this condition. But can you tell me the math behind it? If the two DIDs are mapped, isn't that a non-communication?
The implementation of SendRecv assumes that the input and output are sharded on the same dimension but have different meshes.
There is an isResharding check at the beginning.
Got it -- we'll have to think about how to represent CollectivePermute for ring-based communication.
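To summarize the distinction discussed in this thread (an illustration, not code from the PR):

```cpp
// SendRecv:  producer and consumer are sharded on the same logical dimension
//            (c_logical_id == p2c_map[p_logical_id]) but live on different
//            meshes, e.g. [DIDx{mesh0}(d), m] -> [DIDx{mesh1}(d), m]. The
//            isResharding check up front guarantees this is a real transfer.
// AllToAll:  producer and consumer are sharded on different logical
//            dimensions of the same mesh, e.g.
//            [DIDx(d), n/d, m, ...] -> [n, DIDx(d), m/d, ...].
```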
!test
4 files reviewed, 1 comment
(Quotes the same postAllToAll implementation shown above.)
logic: missing team_size == 1 edge case - when team size is 1, no communication needed, just local copy (see postBroadcast at line 347 or postSendRecv at line 582)
Suggested change (same guard as above, written with `communication->team_size()`):

```diff
   at::Tensor output_tensor) {
+  if (communication->team_size() == 1) {
+    doLocalCopy(output_tensor, input_tensor);
+    return nullptr;
+  }
   NVF_ERROR(
```
csrc/multidevice/communication.cpp (outdated)
```cpp
default:
  NVF_THROW("unrecognized CommunicationType: ", type);
```
Suggested change (delete the `default:` case):

```diff
-default:
-  NVF_THROW("unrecognized CommunicationType: ", type);
```
trust -Wswitch
csrc/multidevice/communication.cpp (outdated)
```cpp
default:
  NVF_THROW("unrecognized CommunicationType: ", type);
```
Suggested change (delete the `default:` case):

```diff
-default:
-  NVF_THROW("unrecognized CommunicationType: ", type);
```
You may need an std::unreachable() after the switch.
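For illustration, the combined suggestion (no `default:`, rely on -Wswitch, close with `std::unreachable()`) could look like this; the enum and the function name are reduced, hypothetical stand-ins, and `std::unreachable()` requires C++23:

```cpp
#include <utility>  // std::unreachable (C++23)

// Hypothetical, reduced stand-in for the real CommunicationType enum and the
// predicate being reviewed; only the structure of the switch matters here.
enum class CommunicationType { Broadcast, SendRecv, AllToAll };

bool somePredicate(CommunicationType type) {
  switch (type) {
    case CommunicationType::Broadcast:
    case CommunicationType::SendRecv:
      return true;
    case CommunicationType::AllToAll:
      return false;
      // No `default:` case: with -Wswitch the compiler warns whenever a new
      // enumerator is added but not handled in this switch.
  }
  std::unreachable();  // Tells the compiler the switch above is exhaustive.
}
```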
```cpp
case CommunicationType::AllToAll:
  return false;
default:
  NVF_THROW("unrecognized CommunicationType: ", type);
```
ditto
!test
4 files reviewed, 1 comment
```cpp
  }
}

c10::intrusive_ptr<c10d::Work> postAllToAll(
```
logic: missing team_size == 1 edge case - when only one device, no communication needed, should return early after local copy (like postBroadcast at line 344)
For the igpu failure, update csrc/multidevice/c10d_mock.h
Thanks!
!build