52 changes: 49 additions & 3 deletions csrc/host_ir/lower_to_communication.cpp
@@ -293,6 +293,40 @@ void lowerToReduceScatter(
backend));
}

void lowerToAllToAll(
TensorView* input_tv,
TensorView* output_tv,
const CommunicatorBackend backend,
std::vector<Expr*>& comms) {
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
NVF_ERROR_EQ(
sender_mesh.rank(),
1,
"AllToAll sender mesh must be a 1D mesh. Given ",
sender_mesh);
NVF_ERROR_EQ(
receiver_mesh.rank(),
1,
"AllToAll receiver mesh must be a 1D mesh. Given ",
receiver_mesh);
NVF_ERROR_EQ(
sender_mesh,
receiver_mesh,
"AllToAll sender and receiver meshes must be the same. Given ",
sender_mesh,
" and ",
receiver_mesh);
comms.push_back(IrBuilder::create<Communication>(
CommunicationType::AllToAll,
output_tv,
input_tv,
sender_mesh.vector(),
/*root=*/-1,
c10d::ReduceOp::RedOpType::UNUSED,
backend));
}

IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
std::unordered_set<IterDomain*> logical_ids =
getInputsInTargetDomain({loop_id}, tv->getLogicalDomain());
@@ -379,8 +413,14 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
// TODO(#4604): This is problematic for 2D sharding.
fill_communication_info(
CommunicationType::SendRecv, p_logical_id, c_logical_id);

if (c_logical_id == p2c_map.at(p_logical_id)) {
Collaborator: As long as tests pass, I'm fine with this condition. But can you tell me the math behind it? If the two DIDs are mapped, isn't that a non-communication?

Author: The implementation of send/recv assumes that the input and output are sharded on the same dimension but have different meshes (see the sketch after this hunk).

Author: There is an isResharding check at the beginning.

Collaborator: Got it -- we'll have to think about how to represent CollectivePermute for ring-based communication.

fill_communication_info(
CommunicationType::SendRecv, p_logical_id, c_logical_id);
} else {
fill_communication_info(
CommunicationType::AllToAll, p_logical_id, c_logical_id);
}
}
} else {
NVF_ERROR(e->isA<ReductionOp>() || e->isA<SqueezeOp>());
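
To make the branch above concrete, here is a hedged, PyTorch-only sketch (not part of the patch; d, n, and the peer index are illustrative) of the data movement implied by each case on a 1D mesh of d ranks:

import torch

# One rank's shard of a (d, d*n) tensor that is sharded on axis 0.
d, n = 4, 3
row_block = torch.arange(d * n)

# Mapped logical IDs, different meshes -> SendRecv: the rank forwards its whole
# shard to a single peer, so the per-peer payload is all-or-nothing.
peer = 1  # illustrative destination
sendrecv_payload = [row_block.numel() if p == peer else 0 for p in range(d)]

# Unmapped logical IDs (the sharded axis moves from axis 0 to axis 1) over the
# same mesh -> AllToAll: the shard is cut into d equal pieces and piece j goes
# to rank j, so every peer receives exactly 1/d of the shard.
alltoall_payload = [piece.numel() for piece in torch.tensor_split(row_block, d)]
assert alltoall_payload == [n] * d

This only illustrates why mapped DIDs with differing meshes reduce to point-to-point traffic while a moved sharded axis requires an all-to-all exchange.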
@@ -441,10 +481,13 @@ Layout getCommunicationLayout(
// For the following communication types, the sharded_id does not have to be
// outermost in allocation domain. Nonetheless, `tv` still needs to be
// contiguous and therefore .contiguous() at the beginning of this function.
// Note: We do not yet reorder for AllToAll and only support cases where the
// input and output do not require any reordering.
if (type == CommunicationType::Reduce ||
type == CommunicationType::Allreduce ||
type == CommunicationType::Broadcast ||
type == CommunicationType::SendRecv) {
type == CommunicationType::SendRecv ||
type == CommunicationType::AllToAll) {
return layout;
}

@@ -568,6 +611,9 @@ std::vector<Expr*> convertSingleOpToCommunication(
case CommunicationType::Reduce:
lowerToReduce(input_tv, output_tv, op_type(e), backend, comms);
break;
case CommunicationType::AllToAll:
lowerToAllToAll(input_tv, output_tv, backend, comms);
break;
}

return comms;
46 changes: 46 additions & 0 deletions csrc/multidevice/communication.cpp
@@ -54,6 +54,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
case CommunicationType::SendRecv:
os << "SendRecv";
break;
case CommunicationType::AllToAll:
os << "AllToAll";
break;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
Collaborator (suggested change): drop the default case above and trust -Wswitch.

}
@@ -152,6 +155,7 @@ bool hasRoot(CommunicationType type) {
case CommunicationType::Allgather:
case CommunicationType::Allreduce:
case CommunicationType::ReduceScatter:
case CommunicationType::AllToAll:
return false;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
Collaborator (suggested change): drop the default case here as well; you may need an std::unreachable() after the switch.

@@ -169,6 +173,7 @@ bool isReduction(CommunicationType type) {
case CommunicationType::Scatter:
case CommunicationType::Broadcast:
case CommunicationType::SendRecv:
case CommunicationType::AllToAll:
return false;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
Collaborator: ditto (same suggestion here).

@@ -605,6 +610,44 @@ c10::intrusive_ptr<c10d::Work> postSendRecv(
/*tag=*/0);
}
}

c10::intrusive_ptr<c10d::Work> postAllToAll(
Communication* communication,
DeviceIdxType my_device_index,
c10d::Backend* backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
NVF_ERROR(
isTvContiguous(communication->in()),
"Input tensor is not contiguous: ",
communication->in(),
" contiguity: ",
communication->in()->domain()->getContiguityString());
NVF_ERROR(
isTvContiguous(communication->out()),
"Output tensor is not contiguous: ",
communication->out(),
" contiguity: ",
communication->out()->domain()->getContiguityString());

auto flattened_input_tensor = viewAsCompact(input_tensor);
auto flattened_output_tensor = viewAsCompact(output_tensor);

// alltoall_base requires even splits of the input and output tensors.
auto input_splits = at::tensor_split(
flattened_input_tensor, communication->team_size(), /*dim=*/0);
auto output_splits = at::tensor_split(
flattened_output_tensor, communication->team_size(), /*dim=*/0);
assertBuffersHaveSameSize(input_splits, output_splits);

std::vector<int64_t> empty_split_sizes;
return backend->alltoall_base(
flattened_output_tensor,
flattened_input_tensor,
empty_split_sizes,
empty_split_sizes,
/*options=*/{});
}
Contributor (on lines +611 to +647): logic: missing team_size == 1 edge case. Other communication functions, like postBroadcast (line 347) and postSendRecv (line 582), handle single-device teams by returning early or doing a local copy; AllToAll should do the same to avoid unnecessary communication overhead and potential runtime errors. Suggested change -- add an early return at the top of postAllToAll:

  if (communication->team_size() == 1) {
    doLocalCopy(output_tensor, input_tensor);
    return nullptr;
  }

} // namespace
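
As a side note on the "alltoall_base requires even splits" comment in postAllToAll above, a minimal sketch (assumed shapes, not part of the patch) of the contract that the empty split-size lists rely on:

import torch

# With no explicit split sizes, alltoall_base exchanges equal dim-0 chunks, so
# dim 0 of the flattened input/output must be divisible by team_size.
team_size = 4
flattened_input = torch.randn(team_size * 6, 5)  # dim 0 divisible by team_size

chunks = torch.tensor_split(flattened_input, team_size, dim=0)
assert all(c.shape[0] == flattened_input.shape[0] // team_size for c in chunks)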

c10::intrusive_ptr<c10d::Work> postSingleCommunication(
@@ -691,6 +734,9 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
backend,
input_tensor,
output_tensor);
case CommunicationType::AllToAll:
return postAllToAll(
communication, my_device_index, backend, input_tensor, output_tensor);
default:
NVF_THROW("Wrong communication type: ", communication->type());
return nullptr;
3 changes: 2 additions & 1 deletion csrc/multidevice/communication.h
@@ -32,7 +32,8 @@ enum class CommunicationType {
Allreduce,
ReduceScatter,
Broadcast,
SendRecv
SendRecv,
AllToAll
};

std::ostream& operator<<(std::ostream& os, const CommunicationType& type);
31 changes: 31 additions & 0 deletions tests/python/multidevice/test_communication.py
@@ -164,3 +164,34 @@ def _multidevice_schedule(fd: FusionDefinition):
torch.testing.assert_close(
output, multidevice_test.shard_tensor_1d(unsharded.sum(0), 1, mesh)
)


# AllToAll patterns seen in expert parallelism
@pytest.mark.mpi
@pytest.mark.parametrize(
"inp_axis,out_axis", [(0, 1), (1, 0)], ids=["dispatch", "combine"]
)
def test_alltoall(multidevice_test, inp_axis, out_axis):
d = multidevice_test.size
mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
n = 3

with FusionDefinition() as fd:
inp = fd.define_tensor((d, d * n), contiguity=True, dtype=DataType.Half)
out = fd.ops.set(inp)
fd.add_output(out)

inp.set_device_mesh(mesh)
inp.outer_split(inp_axis, d)
inp.axis(inp_axis).parallelize(nvfuser.ParallelType.mesh_x)

out.set_device_mesh(mesh)
out.outer_split(out_axis, d)
out.axis(out_axis).parallelize(nvfuser.ParallelType.mesh_x)

in_tensor = torch.randn(d, d * n, dtype=torch.float16)
sharded = multidevice_test.shard_tensor(in_tensor, inp)
(out_tensor,) = fd.execute([sharded])
torch.testing.assert_close(
out_tensor, multidevice_test.shard_tensor(in_tensor, out)
)
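
For intuition on what the dispatch case above asserts, a hedged local simulation of the same exchange (pure PyTorch, no process groups; shapes follow the test, and the loop stands in for the d ranks):

import torch

d, n = 4, 3
unsharded = torch.randn(d, d * n, dtype=torch.float16)

# Before: rank i holds row i (input sharded on axis 0).
inp_shards = [unsharded[i] for i in range(d)]

# A 1D all-to-all: rank i cuts its row into d equal pieces of width n and sends
# piece j to rank j; rank j stacks what it receives from ranks 0..d-1.
out_shards = [
    torch.stack([inp_shards[i].reshape(d, n)[j] for i in range(d)]) for j in range(d)
]

# After: rank j holds column block j (output sharded on axis 1), matching what
# shard_tensor(in_tensor, out) returns for rank j in the test.
for j in range(d):
    torch.testing.assert_close(out_shards[j], unsharded[:, j * n : (j + 1) * n])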