diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp
index 41ad28f0548..00adfaeb985 100644
--- a/csrc/host_ir/lower_to_communication.cpp
+++ b/csrc/host_ir/lower_to_communication.cpp
@@ -293,6 +293,40 @@ void lowerToReduceScatter(
       backend));
 }
 
+void lowerToAllToAll(
+    TensorView* input_tv,
+    TensorView* output_tv,
+    const CommunicatorBackend backend,
+    std::vector<Expr*>& comms) {
+  const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
+  const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
+  NVF_ERROR_EQ(
+      sender_mesh.rank(),
+      1,
+      "AllToAll sender mesh must be a 1D mesh. Given ",
+      sender_mesh);
+  NVF_ERROR_EQ(
+      receiver_mesh.rank(),
+      1,
+      "AllToAll receiver mesh must be a 1D mesh. Given ",
+      receiver_mesh);
+  NVF_ERROR_EQ(
+      sender_mesh,
+      receiver_mesh,
+      "AllToAll sender and receiver meshes must be the same. Given ",
+      sender_mesh,
+      " and ",
+      receiver_mesh);
+  comms.push_back(IrBuilder::create<Communication>(
+      CommunicationType::AllToAll,
+      output_tv,
+      input_tv,
+      sender_mesh.vector(),
+      /*root=*/-1,
+      c10d::ReduceOp::RedOpType::UNUSED,
+      backend));
+}
+
 IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
   std::unordered_set<IterDomain*> logical_ids =
       getInputsInTargetDomain({loop_id}, tv->getLogicalDomain());
@@ -379,8 +413,14 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
       IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
       IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
       // TODO(#4604): This is problematic for 2D sharding.
-      fill_communication_info(
-          CommunicationType::SendRecv, p_logical_id, c_logical_id);
+
+      if (c_logical_id == p2c_map.at(p_logical_id)) {
+        fill_communication_info(
+            CommunicationType::SendRecv, p_logical_id, c_logical_id);
+      } else {
+        fill_communication_info(
+            CommunicationType::AllToAll, nullptr, nullptr);
+      }
     }
   } else {
     NVF_ERROR(e->isA() || e->isA());
@@ -441,10 +481,12 @@ Layout getCommunicationLayout(
 
   // For the following communication types, the sharded_id does not have to be
   // outermost in allocation domain. Nonetheless, `tv` still needs to be
   // contiguous and therefore .contiguous() at the beginning of this function.
+  // TODO(prmishra): Fix the layout for AllToAll.
   if (type == CommunicationType::Reduce ||
       type == CommunicationType::Allreduce ||
       type == CommunicationType::Broadcast ||
-      type == CommunicationType::SendRecv) {
+      type == CommunicationType::SendRecv ||
+      type == CommunicationType::AllToAll) {
     return layout;
   }
 
@@ -568,6 +610,9 @@ std::vector<Expr*> convertSingleOpToCommunication(
     case CommunicationType::Reduce:
       lowerToReduce(input_tv, output_tv, op_type(e), backend, comms);
       break;
+    case CommunicationType::AllToAll:
+      lowerToAllToAll(input_tv, output_tv, backend, comms);
+      break;
   }
 
   return comms;
diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp
index 2a61af07b15..0756fcecc4c 100644
--- a/csrc/multidevice/communication.cpp
+++ b/csrc/multidevice/communication.cpp
@@ -54,6 +54,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
     case CommunicationType::SendRecv:
       os << "SendRecv";
       break;
+    case CommunicationType::AllToAll:
+      os << "AllToAll";
+      break;
     default:
       NVF_THROW("unrecognized CommunicationType: ", type);
   }
@@ -152,6 +155,7 @@ bool hasRoot(CommunicationType type) {
     case CommunicationType::Allgather:
     case CommunicationType::Allreduce:
     case CommunicationType::ReduceScatter:
+    case CommunicationType::AllToAll:
       return false;
     default:
       NVF_THROW("unrecognized CommunicationType: ", type);
@@ -169,6 +173,7 @@ bool isReduction(CommunicationType type) {
     case CommunicationType::Scatter:
     case CommunicationType::Broadcast:
     case CommunicationType::SendRecv:
+    case CommunicationType::AllToAll:
       return false;
     default:
       NVF_THROW("unrecognized CommunicationType: ", type);
@@ -605,6 +610,71 @@ c10::intrusive_ptr<c10d::Work> postSendRecv(
         /*tag=*/0);
   }
 }
+
+c10::intrusive_ptr<c10d::Work> postAllToAll(
+    Communication* communication,
+    DeviceIdxType my_device_index,
+    c10d::Backend* backend,
+    at::Tensor input_tensor,
+    at::Tensor output_tensor) {
+  NVF_ERROR(
+      isTvContiguous(communication->in()),
+      "Input tensor is not contiguous: ",
+      communication->in(),
+      " contiguity: ",
+      communication->in()->domain()->getContiguityString());
+  NVF_ERROR(
+      isTvContiguous(communication->out()),
+      "Output tensor is not contiguous: ",
+      communication->out(),
+      " contiguity: ",
+      communication->out()->domain()->getContiguityString());
+
+  // input_tv = [DIDx(d), n/d, m, ...]
+  // output_tv = [n, DIDx(d), m/d, ...]
+  // `n`: gathered dimension
+  // `m`: scattered dimension
+  // For alltoall correctness, we split `m` and reorder as [DIDx(d), d, n/d,
+  // m/d, ...] such that alltoall_base splits across the `d` dimension.
+
+  int64_t d = communication->team_size();
+  auto input_sizes = input_tensor.sizes();
+
+  NVF_CHECK(
+      input_sizes.at(1) % d == 0,
+      "Scattered dimension must be divisible by the team size");
+
+  std::vector<int64_t> input_reshape_sizes(
+      input_sizes.begin(), input_sizes.end());
+  input_reshape_sizes.at(1) = d;
+  input_reshape_sizes.insert(
+      input_reshape_sizes.begin() + 2, input_sizes.at(1) / d);
+  auto reshaped_input = input_tensor.reshape(input_reshape_sizes);
+
+  std::vector<int64_t> permute_dims(input_reshape_sizes.size());
+  std::iota(permute_dims.begin(), permute_dims.end(), 0);
+  std::swap(permute_dims[0], permute_dims[1]);
+
+  auto reordered_input = reshaped_input.permute(permute_dims).contiguous();
+
+  auto flattened_input_tensor = viewAsCompact(reordered_input);
+  auto flattened_output_tensor = viewAsCompact(output_tensor);
+
+  // alltoall_base requires even splits of the input and output tensors.
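+  // Passing empty split-size vectors below makes alltoall_base divide dim 0
+  // of both buffers into `team_size` equal chunks; the tensor_split calls
+  // only feed assertBuffersHaveSameSize to validate that the chunks match.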
+  auto input_splits = at::tensor_split(
+      flattened_input_tensor, communication->team_size(), /*dim=*/0);
+  auto output_splits = at::tensor_split(
+      flattened_output_tensor, communication->team_size(), /*dim=*/0);
+  assertBuffersHaveSameSize(input_splits, output_splits);
+
+  std::vector<int64_t> empty_split_sizes;
+  return backend->alltoall_base(
+      flattened_output_tensor,
+      flattened_input_tensor,
+      empty_split_sizes,
+      empty_split_sizes,
+      /*options=*/{});
+}
 } // namespace
 
 c10::intrusive_ptr<c10d::Work> postSingleCommunication(
@@ -691,6 +761,9 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
           backend,
           input_tensor,
           output_tensor);
+    case CommunicationType::AllToAll:
+      return postAllToAll(
+          communication, my_device_index, backend, input_tensor, output_tensor);
     default:
       NVF_THROW("Wrong communication type: ", communication->type());
       return nullptr;
diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h
index 358504f81de..1a7f1a1cc4c 100644
--- a/csrc/multidevice/communication.h
+++ b/csrc/multidevice/communication.h
@@ -32,7 +32,8 @@ enum class CommunicationType {
   Allreduce,
   ReduceScatter,
   Broadcast,
-  SendRecv
+  SendRecv,
+  AllToAll
 };
 
 std::ostream& operator<<(std::ostream& os, const CommunicationType& type);
diff --git a/tests/python/multidevice/test_multidevice.py b/tests/python/multidevice/test_multidevice.py
index ab4930a7c7d..164ab568085 100644
--- a/tests/python/multidevice/test_multidevice.py
+++ b/tests/python/multidevice/test_multidevice.py
@@ -421,3 +421,74 @@ def test_binary(multidevice_test):
     (z,) = fd.execute([x_ref.cuda(), y])
 
     torch.testing.assert_close(z, multidevice_test.shard_tensor(z_ref, z_tv))
+
+
+@pytest.mark.mpi
+def test_alltoall(multidevice_test):
+    d = multidevice_test.size
+    mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
+    k, m, n = 3, 40, 56
+
+    assert (
+        m % d == 0 and n % d == 0
+    ), "Sharded dimensions must be divisible by the team size"
+
+    with FusionDefinition() as fd:
+        inp = fd.define_tensor((k, m, n), contiguity=True, dtype=DataType.Half)
+        # Ideally, we could have exposed and split the allocation domain of
+        # the alltoall input and output, as follows:
+        #   Input of shape [k, m, n] sharded on n
+        #   Output of shape [k, m, n] sharded on m
+        #
+        # Where d is an outer split of `n` and d* is an outer split of `m`:
+        #   input [k, m, DIDx(d), n/d]
+        #   alltoall_input [d*, k, m/d*, DIDx(d), n/d]
+        #     d* is reordered outermost. `alltoall_base` will slice on the
+        #     outermost dimension.
+        #   alltoall_output [d, k, DIDx(d*), m/d*, n/d]
+        #     d is gathered back. AllToAll places data from each device in
+        #     memory in order of rank.
+        #   output [k, DIDx(d*), m/d*, n]
+        #     Same as sharding the original input along the row dimension
+        #     instead of the column.
+        #
+        # However, tensors whose allocation domains contain non-adjacent
+        # splits fail stride validation. To avoid this, I permute the
+        # gathered and scattered dimensions to be outermost in that order.
+        # `postAllToAll` will split the scattered dimension, reorder to
+        # [DIDx(d), d, n/d, m/d, k], and make it contiguous. The output will
+        # be [DIDx(d*), n, m, k]. This avoids the non-adjacent split in the
+        # output and avoids exposing the non-adjacent split of `m` in the
+        # input to the fusion definition. I use `permute` instead of setting
+        # the allocation domain to avoid another copy.
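+        # The fusion below is therefore logically an identity (permute ->
+        # set -> permute); only the sharding changes, which is what lowers
+        # the `set` to an AllToAll.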
+        all2all_inp = fd.ops.permute(inp, dims=[2, 1, 0])
+        all2all_out = fd.ops.set(all2all_inp)
+        out = fd.ops.permute(all2all_out, dims=[2, 1, 0])
+        fd.add_output(out)
+
+        inp.set_device_mesh(mesh)
+        inp.outer_split(2, d)
+        inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x)
+
+        all2all_inp.set_device_mesh(mesh)
+        all2all_inp.outer_split(0, d)
+        all2all_inp.axis(0).parallelize(
+            nvfuser.ParallelType.mesh_x
+        )  # [DIDx(d), n, d * m, k]
+
+        all2all_out.set_device_mesh(mesh)
+        all2all_out.outer_split(1, d)
+        all2all_out.axis(1).parallelize(
+            nvfuser.ParallelType.mesh_x
+        )  # [d * n, DIDx(d), m, k]
+        all2all_out.set_allocation_domain(all2all_out.get_loop_domain(), True)
+
+        out.set_device_mesh(mesh)
+        out.outer_split(1, d)
+        out.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
+        out.set_allocation_domain(out.get_loop_domain(), True)
+
+    in_tensor = torch.randn(k, m, n, dtype=torch.float16)
+    sharded = multidevice_test.shard_tensor(in_tensor, inp)
+    (out_tensor,) = fd.execute([sharded])
+    torch.testing.assert_close(
+        out_tensor, multidevice_test.shard_tensor(in_tensor, out)
+    )
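A minimal single-process sketch (not part of the patch) of the data movement
`postAllToAll` performs, useful for checking the reshape/permute bookkeeping.
The fixed `d = 4` and the simulated exchange are illustrative assumptions; a
real run goes through `backend->alltoall_base` instead.

    # Single-process simulation of postAllToAll's reshape/permute/exchange.
    import torch

    d = 4                # stand-in for the team size
    k, m, n = 3, 40, 56  # shapes from test_alltoall; m % d == 0 and n % d == 0

    full = torch.randn(k, m, n)
    permuted = full.permute(2, 1, 0)  # [n, m, k], like all2all_inp

    # Rank r starts with a shard of the gathered dim n: [n/d, m, k].
    in_shards = list(permuted.tensor_split(d, dim=0))

    # postAllToAll per rank: split the scattered dim m into [d, m/d] and move
    # d outermost, so chunk p of rank r's send buffer is destined for peer p.
    send = [
        s.reshape(n // d, d, m // d, k).permute(1, 0, 2, 3).contiguous()
        for s in in_shards
    ]

    # alltoall: rank r receives chunk r from every peer, concatenated in rank
    # order, reassembling the full n dim while keeping only its m/d columns.
    for r in range(d):
        recv = torch.cat([send[p][r] for p in range(d)], dim=0)  # [n, m/d, k]
        expected = permuted.tensor_split(d, dim=1)[r]
        assert torch.equal(recv, expected)

The `.contiguous()` before the exchange is what makes this work: alltoall_base
exchanges raw, evenly sized chunks of the flattened buffer, so the send order
must match the reordered memory layout.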