AllToAll implementation #5705
```diff
@@ -293,6 +293,40 @@ void lowerToReduceScatter(
       backend));
 }
 
+void lowerToAllToAll(
+    TensorView* input_tv,
+    TensorView* output_tv,
+    const CommunicatorBackend backend,
+    std::vector<Expr*>& comms) {
+  const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
+  const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
+  NVF_ERROR_EQ(
+      sender_mesh.rank(),
+      1,
+      "AllToAll sender mesh must be a 1D mesh. Given ",
+      sender_mesh);
+  NVF_ERROR_EQ(
+      receiver_mesh.rank(),
+      1,
+      "AllToAll receiver mesh must be a 1D mesh. Given ",
+      receiver_mesh);
+  NVF_ERROR_EQ(
+      sender_mesh,
+      receiver_mesh,
+      "AllToAll sender and receiver meshes must be the same. Given ",
+      sender_mesh,
+      " and ",
+      receiver_mesh);
+  comms.push_back(IrBuilder::create<Communication>(
+      CommunicationType::AllToAll,
+      output_tv,
+      input_tv,
+      sender_mesh.vector(),
+      /*root=*/-1,
+      c10d::ReduceOp::RedOpType::UNUSED,
+      backend));
+}
+
 IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
   std::unordered_set<IterDomain*> logical_ids =
       getInputsInTargetDomain({loop_id}, tv->getLogicalDomain());
```
```diff
@@ -379,8 +413,14 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
     IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
     IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
     // TODO(#4604): This is problematic for 2D sharding.
-    fill_communication_info(
-        CommunicationType::SendRecv, p_logical_id, c_logical_id);
+    if (c_logical_id == p2c_map.at(p_logical_id)) {
```
Collaborator
As long as tests pass, I'm fine with this condition. But can you tell me the math behind it? If the two DIDs are mapped, isn't that a non-communication?

Collaborator (Author)
The implementation of send recv assumes that the input and output are sharded on the same dimension, but have different meshes.

Collaborator (Author)
There is

Collaborator
Got it -- we'll have to think about how to represent CollectivePermute for ring-based communication.
```diff
+      fill_communication_info(
+          CommunicationType::SendRecv, p_logical_id, c_logical_id);
+    } else {
+      fill_communication_info(
+          CommunicationType::AllToAll, p_logical_id, c_logical_id);
+    }
   }
 } else {
   NVF_ERROR(e->isA<ReductionOp>() || e->isA<SqueezeOp>());
```
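To make the math behind this condition concrete, here is a small standalone sketch, not nvFuser code and with made-up names: when the producer and consumer shard different logical axes over the same 1D mesh, every rank must exchange one chunk with every other rank, which is exactly the AllToAll pattern; when the same axis stays sharded but the mesh changes, each chunk has a single source and a single destination, which is SendRecv.

```cpp
#include <cstdio>

int main() {
  constexpr int kNumRanks = 4;
  // View the tensor as a kNumRanks x kNumRanks grid of chunks.
  // Input: rank r owns row-block r (sharded on axis 0).
  // Output: rank r owns column-block r (sharded on axis 1).
  for (int row = 0; row < kNumRanks; ++row) {
    for (int col = 0; col < kNumRanks; ++col) {
      // Chunk (row, col) starts on rank `row` and must end on rank `col`,
      // so every (source, destination) pair occurs exactly once: the
      // communication pattern of alltoall_base.
      std::printf("chunk(%d,%d): rank %d -> rank %d\n", row, col, row, col);
    }
  }
  return 0;
}
```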
```diff
@@ -441,10 +481,13 @@ Layout getCommunicationLayout(
   // For the following communication types, the sharded_id does not have to be
   // outermost in allocation domain. Nonetheless, `tv` still needs to be
   // contiguous and therefore .contiguous() at the beginning of this function.
+  // Note: We do not yet reorder for AllToAll and only support cases where the
+  // input and output do not require any reordering.
   if (type == CommunicationType::Reduce ||
       type == CommunicationType::Allreduce ||
       type == CommunicationType::Broadcast ||
-      type == CommunicationType::SendRecv) {
+      type == CommunicationType::SendRecv ||
+      type == CommunicationType::AllToAll) {
     return layout;
   }
```
```diff
@@ -568,6 +611,9 @@ std::vector<Expr*> convertSingleOpToCommunication(
     case CommunicationType::Reduce:
       lowerToReduce(input_tv, output_tv, op_type(e), backend, comms);
       break;
+    case CommunicationType::AllToAll:
+      lowerToAllToAll(input_tv, output_tv, backend, comms);
+      break;
   }
 
   return comms;
```
```diff
@@ -54,6 +54,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
     case CommunicationType::SendRecv:
       os << "SendRecv";
       break;
+    case CommunicationType::AllToAll:
+      os << "AllToAll";
+      break;
     default:
       NVF_THROW("unrecognized CommunicationType: ", type);
```
```cpp
    default:
      NVF_THROW("unrecognized CommunicationType: ", type);
```

trust -Wswitch (Outdated)

You may need an std::unreachable() after the switch.

ditto
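For reference, here is a minimal sketch of that pattern, assuming C++23's std::unreachable() is available; the enum and operator<< below are simplified stand-ins rather than the actual nvFuser definitions. Dropping the default: case lets -Wswitch warn about unhandled enumerators at compile time, and the std::unreachable() after the switch silences missing-return warnings.

```cpp
#include <ostream>
#include <utility>

// Simplified stand-in for the real enum.
enum class CommunicationType { SendRecv, AllToAll };

std::ostream& operator<<(std::ostream& os, CommunicationType type) {
  switch (type) {
    case CommunicationType::SendRecv:
      return os << "SendRecv";
    case CommunicationType::AllToAll:
      return os << "AllToAll";
  }
  // No default: -Wswitch flags any enumerator missing from the cases above.
  std::unreachable();
}
```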
logic: Missing team_size == 1 edge case handling. Other communication functions like postBroadcast (line 347) and postSendRecv (line 582) handle single-device teams by either returning early or doing local copies. AllToAll should handle this case to avoid unnecessary communication overhead and potential runtime errors.
```diff
 c10::intrusive_ptr<c10d::Work> postAllToAll(
     Communication* communication,
     DeviceIdxType my_device_index,
     c10d::Backend* backend,
     at::Tensor input_tensor,
     at::Tensor output_tensor) {
+  if (communication->team().size() == 1) {
+    doLocalCopy(output_tensor, input_tensor);
+    return nullptr;
+  }
   NVF_ERROR(
       isTvContiguous(communication->in()),
       "Input tensor is not contiguous: ",
       communication->in(),
       " contiguity: ",
       communication->in()->domain()->getContiguityString());
   NVF_ERROR(
       isTvContiguous(communication->out()),
       "Output tensor is not contiguous: ",
       communication->out(),
       " contiguity: ",
       communication->out()->domain()->getContiguityString());
   auto flattened_input_tensor = viewAsCompact(input_tensor);
   auto flattened_output_tensor = viewAsCompact(output_tensor);
   // alltoall_base requires even splits of the input and output tensors.
   auto input_splits = at::tensor_split(
       flattened_input_tensor, communication->team_size(), /*dim=*/0);
   auto output_splits = at::tensor_split(
       flattened_output_tensor, communication->team_size(), /*dim=*/0);
   assertBuffersHaveSameSize(input_splits, output_splits);
   std::vector<int64_t> empty_split_sizes;
   return backend->alltoall_base(
       flattened_output_tensor,
       flattened_input_tensor,
       empty_split_sizes,
       empty_split_sizes,
       /*options=*/{});
 }
```
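As a side note on the "even splits" comment above, here is a small standalone sketch, assuming libtorch/ATen is available and not part of the PR: with empty split sizes, alltoall_base divides dim 0 of both flattened tensors evenly across the team, so dim 0 must be divisible by the team size and every per-rank chunk must have the same shape.

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  constexpr int64_t kTeamSize = 4;
  // Dim 0 (8) is divisible by the team size (4), so each split is 2x3.
  at::Tensor input = at::zeros({8, 3});
  at::Tensor output = at::empty_like(input);

  auto input_splits = at::tensor_split(input, kTeamSize, /*dim=*/0);
  auto output_splits = at::tensor_split(output, kTeamSize, /*dim=*/0);
  for (int64_t i = 0; i < kTeamSize; ++i) {
    // With an uneven dim 0 (e.g. 7 rows), the chunks would differ in size
    // and the even-split requirement would be violated.
    std::cout << "split " << i << ": " << input_splits[i].sizes() << " vs "
              << output_splits[i].sizes() << "\n";
  }
  return 0;
}
```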