51 changes: 48 additions & 3 deletions csrc/host_ir/lower_to_communication.cpp
@@ -293,6 +293,40 @@ void lowerToReduceScatter(
backend));
}

void lowerToAllToAll(
TensorView* input_tv,
TensorView* output_tv,
const CommunicatorBackend backend,
std::vector<Expr*>& comms) {
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
NVF_ERROR_EQ(
sender_mesh.rank(),
1,
"AllToAll sender mesh must be a 1D mesh. Given ",
sender_mesh);
NVF_ERROR_EQ(
receiver_mesh.rank(),
1,
"AllToAll receiver mesh must be a 1D mesh. Given ",
receiver_mesh);
NVF_ERROR_EQ(
sender_mesh,
receiver_mesh,
"AllToAll sender and receiver meshes must be the same. Given ",
sender_mesh,
" and ",
receiver_mesh);
comms.push_back(IrBuilder::create<Communication>(
CommunicationType::AllToAll,
output_tv,
input_tv,
sender_mesh.vector(),
/*root=*/-1,
c10d::ReduceOp::RedOpType::UNUSED,
backend));
}

IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
std::unordered_set<IterDomain*> logical_ids =
getInputsInTargetDomain({loop_id}, tv->getLogicalDomain());
@@ -379,8 +413,14 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
// TODO(#4604): This is problematic for 2D sharding.
fill_communication_info(
CommunicationType::SendRecv, p_logical_id, c_logical_id);

if (c_logical_id == p2c_map.at(p_logical_id)) {
Collaborator: As long as tests pass, I'm fine with this condition. But can you tell me the math behind it? If the two DIDs are mapped, isn't that a non-communication?

Collaborator Author: The implementation of SendRecv assumes that the input and output are sharded on the same dimension but have different meshes.

Collaborator Author: There is an isResharding check at the beginning.

Collaborator: Got it -- we'll have to think about how to represent CollectivePermute for ring-based communication.
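
For intuition, a minimal sketch of the two shardings this condition separates (hypothetical tensors, not taken from the PR; notation follows the comments in postAllToAll below):

// SendRecv: producer and consumer are sharded on the same logical dimension
// but live on different meshes, so c_logical_id == p2c_map.at(p_logical_id).
//   producer: [DIDx(d), m/d, n] on mesh {0, 1}
//   consumer: [DIDx(d), m/d, n] on mesh {2, 3}
//
// AllToAll: producer and consumer share one mesh, but the sharding moves to a
// different logical dimension (gather m, scatter n), so the mapped IDs differ.
//   producer: [DIDx(d), m/d, n] on mesh {0, 1}
//   consumer: [m, DIDx(d), n/d] on mesh {0, 1}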

fill_communication_info(
CommunicationType::SendRecv, p_logical_id, c_logical_id);
} else {
fill_communication_info(
CommunicationType::AllToAll, nullptr, nullptr);
}
Contributor (on lines 417 to 423):

Setting both p_sharded_id and c_sharded_id to nullptr for AllToAll communication may cause issues in other parts of the codebase that expect valid IterDomain pointers.

For example, getCommunicationLayout() (lines 476-521) is called with sharded_id as a parameter, and at line 493 it checks posInDomain(layout.allocation_domain(), sharded_id), which could fail if sharded_id is nullptr.

The current workaround is that AllToAll is explicitly handled in getCommunicationLayout() at line 489 to return early, but this creates tight coupling and a fragile dependency. Consider one of the following:

  1. Storing the actual producer and consumer sharded IDs for AllToAll (even though they map to different logical dimensions)
  2. Adding explicit nullptr checks in all functions that consume CommunicationInfo (see the sketch after this list)
  3. Documenting this assumption clearly in the CommunicationInfo struct definition
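
A minimal sketch of option 2, assuming getCommunicationLayout() keeps its current signature (hypothetical guard, not part of this PR):

// Hypothetical guard at the top of getCommunicationLayout(): AllToAll is the
// only communication type expected to carry a null sharded_id.
NVF_ERROR(
    sharded_id != nullptr || type == CommunicationType::AllToAll,
    "Only AllToAll may have a null sharded_id in CommunicationInfo.");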

}
} else {
NVF_ERROR(e->isA<ReductionOp>() || e->isA<SqueezeOp>());
@@ -441,10 +481,12 @@ Layout getCommunicationLayout(
// For the following communication types, the sharded_id does not have to be
// outermost in allocation domain. Nonetheless, `tv` still needs to be
// contiguous and therefore .contiguous() at the beginning of this function.
// TODO(prmishra): Fix the layout for AllToAll.
Collaborator Author: Depending on where we relayout the input/output to be compliant with the alltoall requirements, this function and potentially ReorderShardedAxisPass will be affected. I will do this in a follow-up PR once the current PR has been reviewed and we agree on the approach for reordering the input/output of alltoall.

if (type == CommunicationType::Reduce ||
type == CommunicationType::Allreduce ||
type == CommunicationType::Broadcast ||
type == CommunicationType::SendRecv) {
type == CommunicationType::SendRecv ||
type == CommunicationType::AllToAll) {
return layout;
}

@@ -568,6 +610,9 @@ std::vector<Expr*> convertSingleOpToCommunication(
case CommunicationType::Reduce:
lowerToReduce(input_tv, output_tv, op_type(e), backend, comms);
break;
case CommunicationType::AllToAll:
lowerToAllToAll(input_tv, output_tv, backend, comms);
break;
}

return comms;
73 changes: 73 additions & 0 deletions csrc/multidevice/communication.cpp
@@ -54,6 +54,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
case CommunicationType::SendRecv:
os << "SendRecv";
break;
case CommunicationType::AllToAll:
os << "AllToAll";
break;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
Collaborator (suggested change: remove these lines):

    default:
      NVF_THROW("unrecognized CommunicationType: ", type);

trust -Wswitch

}
@@ -152,6 +155,7 @@ bool hasRoot(CommunicationType type) {
case CommunicationType::Allgather:
case CommunicationType::Allreduce:
case CommunicationType::ReduceScatter:
case CommunicationType::AllToAll:
return false;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
Collaborator (suggested change: remove these lines):

    default:
      NVF_THROW("unrecognized CommunicationType: ", type);

You may need an std::unreachable() after the switch.
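
A minimal, self-contained sketch of the suggested pattern (generic enum for illustration; std::unreachable() requires C++23):

#include <utility>  // std::unreachable

enum class Color { Red, Green };

int code(Color c) {
  switch (c) {
    case Color::Red:
      return 0;
    case Color::Green:
      return 1;
  }
  // Every enumerator is handled above. With no default case, -Wswitch warns
  // at compile time if a new enumerator is added but not handled, and
  // std::unreachable() silences the missing-return warning on this path.
  std::unreachable();
}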

@@ -169,6 +173,7 @@ bool isReduction(CommunicationType type) {
case CommunicationType::Scatter:
case CommunicationType::Broadcast:
case CommunicationType::SendRecv:
case CommunicationType::AllToAll:
return false;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
Collaborator: ditto

@@ -605,6 +610,71 @@ c10::intrusive_ptr<c10d::Work> postSendRecv(
/*tag=*/0);
}
}

c10::intrusive_ptr<c10d::Work> postAllToAll(
Contributor: logic: missing team_size == 1 edge case. When there is only one device, no communication is needed; the function should return early after a local copy (like postBroadcast at line 344).
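
A minimal sketch of the suggested early return, assuming the single-device path works the way the reviewer describes for postBroadcast (hypothetical, not part of this PR):

// Hypothetical single-device fast path: with a team of one, AllToAll
// degenerates to a local copy, and no collective needs to be posted.
if (communication->team_size() == 1) {
  if (input_tensor.data_ptr() != output_tensor.data_ptr()) {
    output_tensor.copy_(input_tensor);
  }
  return nullptr;
}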

Communication* communication,
DeviceIdxType my_device_index,
c10d::Backend* backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
NVF_ERROR(
isTvContiguous(communication->in()),
"Input tensor is not contiguous: ",
communication->in(),
" contiguity: ",
communication->in()->domain()->getContiguityString());
NVF_ERROR(
isTvContiguous(communication->out()),
"Output tensor is not contiguous: ",
communication->out(),
" contiguity: ",
communication->out()->domain()->getContiguityString());

// input_tv = [DIDx(d), n/d, m, ...]
// output_tv = [n, DIDx(d), m/d, ...]
// `n`: gathered dimension
// `m`: scattered dimension
// For alltoall correctness, we split `m` and reorder as [DIDx(d), d, n/d,
// m/d, ...] such that alltoall_base splits across the `d` dimension.
Contributor:

The comment describing the input/output tensor shapes appears to be inconsistent with the actual implementation below.

The comment states:

  • input_tv = [DIDx(d), n/d, m, ...]
  • m: scattered dimension (at position 2)

But the code at lines 644-651 operates on dimension 1 (input_sizes.at(1)), treating it as the scattered dimension. Based on the test case in test_alltoall, the actual runtime input tensor shape is [n/d, m, k, ...] (after permutation), where:

  • Position 0: n/d (gathered dimension)
  • Position 1: m (scattered dimension)
  • Position 2+: other dimensions

The comment should be updated to reflect the actual runtime tensor layout rather than the logical sharding representation, or clarify that it's describing logical sharding rather than physical memory layout.


int64_t d = communication->team_size();
auto input_sizes = input_tensor.sizes();

NVF_CHECK(
input_sizes.at(1) % d == 0,
"Scattered dimension must be divisible by the team size");
Contributor:

The error message says "Scattered dimension must be divisible by the team size" but this check only validates one dimension.

For a complete AllToAll operation, both the scattered dimension (at position 1 in the runtime tensor) and the gathered dimension (at position 0) should be validated:

  1. The scattered dimension input_sizes.at(1) must be divisible by d (currently checked ✓)
  2. The gathered dimension should result in an output where the gathered size equals d * input_sizes.at(0)

Consider adding validation for the output tensor dimensions as well, to catch shape mismatches early rather than relying on assertBuffersHaveSameSize at line 668, which only provides a generic error.
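
A minimal sketch of such a check (hedged; an element-count comparison is one cheap way to surface shape mismatches before the collective runs):

// Hypothetical early validation: AllToAll exchanges equal-sized chunks, so
// input and output must carry the same total number of elements.
NVF_CHECK(
    input_tensor.numel() == output_tensor.numel(),
    "AllToAll input and output must have the same number of elements, got ",
    input_tensor.numel(),
    " and ",
    output_tensor.numel());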


std::vector<int64_t> input_reshape_sizes(
input_sizes.begin(), input_sizes.end());
input_reshape_sizes.at(1) = d;
input_reshape_sizes.insert(
input_reshape_sizes.begin() + 2, input_sizes.at(1) / d);
auto reshaped_input = input_tensor.reshape(input_reshape_sizes);

std::vector<int64_t> permute_dims(input_reshape_sizes.size());
std::iota(permute_dims.begin(), permute_dims.end(), 0);
std::swap(permute_dims[0], permute_dims[1]);

auto reordered_input = reshaped_input.permute(permute_dims).contiguous();

auto flattened_input_tensor = viewAsCompact(reordered_input);
auto flattened_output_tensor = viewAsCompact(output_tensor);

// alltoall_base requires even splits of the input and output tensors.
auto input_splits = at::tensor_split(
flattened_input_tensor, communication->team_size(), /*dim=*/0);
auto output_splits = at::tensor_split(
flattened_output_tensor, communication->team_size(), /*dim=*/0);
assertBuffersHaveSameSize(input_splits, output_splits);

std::vector<int64_t> empty_split_sizes;
return backend->alltoall_base(
flattened_output_tensor,
flattened_input_tensor,
empty_split_sizes,
empty_split_sizes,
/*options=*/{});
}
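
To make the data movement concrete, here is a standalone libtorch sketch (hypothetical sizes; dim 1 is the scattered dimension, matching the code above) checking that the reshape-plus-permute places the chunk destined for rank j at position j of dim 0, the dimension alltoall_base splits:

#include <ATen/ATen.h>

int main() {
  const int64_t a = 3, m = 4, k = 5, d = 2;  // m: scattered dim, d: team size
  at::Tensor input = at::arange(a * m * k).reshape({a, m, k});
  // Same transform as postAllToAll: split dim 1 into (d, m/d), then move the
  // device factor d outermost and make it contiguous.
  at::Tensor reordered =
      input.reshape({a, d, m / d, k}).permute({1, 0, 2, 3}).contiguous();
  for (int64_t j = 0; j < d; ++j) {
    // The chunk destined for rank j is columns [j*m/d, (j+1)*m/d) of dim 1.
    at::Tensor chunk = input.slice(/*dim=*/1, j * (m / d), (j + 1) * (m / d));
    TORCH_CHECK(at::equal(reordered[j], chunk));
  }
  return 0;
}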
} // namespace

c10::intrusive_ptr<c10d::Work> postSingleCommunication(
@@ -691,6 +761,9 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
backend,
input_tensor,
output_tensor);
case CommunicationType::AllToAll:
return postAllToAll(
communication, my_device_index, backend, input_tensor, output_tensor);
default:
NVF_THROW("Wrong communication type: ", communication->type());
return nullptr;
3 changes: 2 additions & 1 deletion csrc/multidevice/communication.h
@@ -32,7 +32,8 @@ enum class CommunicationType {
Allreduce,
ReduceScatter,
Broadcast,
SendRecv
SendRecv,
AllToAll
};

std::ostream& operator<<(std::ostream& os, const CommunicationType& type);
71 changes: 71 additions & 0 deletions tests/python/multidevice/test_multidevice.py
@@ -421,3 +421,74 @@ def test_binary(multidevice_test):
(z,) = fd.execute([x_ref.cuda(), y])

torch.testing.assert_close(z, multidevice_test.shard_tensor(z_ref, z_tv))


@pytest.mark.mpi
def test_alltoall(multidevice_test):
d = multidevice_test.size
mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
k, m, n = 3, 40, 56

assert (
m % d == 0 and n % d == 0
), "Sharded dimensions must be divisible by the team size"

with FusionDefinition() as fd:
inp = fd.define_tensor((k, m, n), contiguity=True, dtype=DataType.Half)
# Ideally, we could have exposed and split the allocation domain of
# alltoall input and output, as follows:
# Input of shape [k, m, n] sharded on n
# Output of shape [k, m, n] sharded on m
#
# Where d is an outer split from `n` and d* is an outer split from `m`:
# input [b, m, DIDx(d), n/d]
# alltoall_input [d*, b, m/d*, DIDx(d), n/d]
# d* is reordered outermost. `alltoall_single` will slice on the
# outermost dimension.
# alltoall_output [d, b, DIDx(d*), m/d*, n/d]
# d is gathered back. AllToAll places data from each device in
# memory in order of rank.
# output [b, DIDx(d*), m/d, n]
# Same as sharding original input along row dimension instead of
# column.
#
# However, tensors corresponding to allocation domains with non-adjacent
# splits fail stride validation. To avoid this, I permute the gathered
# and scattered dimensions to be outermost in that order. `postAllToAll`
# will split the scattered dimension and reorder as
# [DIDx(d), d, n/d, m/d, k] and make it contiguous. The output will be
# [DIDx(d*), n, m, k]. This avoids the non-adjacent split in output and
# avoids exposing the non-adjacent split of `m` in input to the fusion
# definition. I am using `permute` instead of setting the allocation
# domain to avoid another copy.
all2all_inp = fd.ops.permute(inp, dims=[2, 1, 0])
all2all_out = fd.ops.set(all2all_inp)
out = fd.ops.permute(all2all_out, dims=[2, 1, 0])
fd.add_output(out)

inp.set_device_mesh(mesh)
inp.outer_split(2, d)
inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x)

all2all_inp.set_device_mesh(mesh)
all2all_inp.outer_split(0, d)
all2all_inp.axis(0).parallelize(
nvfuser.ParallelType.mesh_x
) # [DIDx(d), n, d * m, k]

all2all_out.set_device_mesh(mesh)
all2all_out.outer_split(1, d)
all2all_out.axis(1).parallelize(
nvfuser.ParallelType.mesh_x
) # [d * n, DIDx(d), m, k]
all2all_out.set_allocation_domain(all2all_out.get_loop_domain(), True)

out.set_device_mesh(mesh)
out.outer_split(1, d)
out.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
out.set_allocation_domain(out.get_loop_domain(), True)

in_tensor = torch.randn(k, m, n, dtype=torch.float16)
sharded = multidevice_test.shard_tensor(in_tensor, inp)
(out_tensor,) = fd.execute([sharded])
torch.testing.assert_close(out_tensor.cpu(), multidevice_test.shard_tensor(in_tensor, out))