51 changes: 48 additions & 3 deletions csrc/host_ir/lower_to_communication.cpp
@@ -293,6 +293,40 @@ void lowerToReduceScatter(
backend));
}

void lowerToAllToAll(
TensorView* input_tv,
TensorView* output_tv,
const CommunicatorBackend backend,
std::vector<Expr*>& comms) {
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
NVF_ERROR_EQ(
sender_mesh.rank(),
1,
"AllToAll sender mesh must be a 1D mesh. Given ",
sender_mesh);
NVF_ERROR_EQ(
receiver_mesh.rank(),
1,
"AllToAll receiver mesh must be a 1D mesh. Given ",
receiver_mesh);
NVF_ERROR_EQ(
sender_mesh,
receiver_mesh,
"AllToAll sender and receiver meshes must be the same. Given ",
sender_mesh,
" and ",
receiver_mesh);
comms.push_back(IrBuilder::create<Communication>(
CommunicationType::AllToAll,
output_tv,
input_tv,
sender_mesh.vector(),
/*root=*/-1,
c10d::ReduceOp::RedOpType::UNUSED,
backend));
}

IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
std::unordered_set<IterDomain*> logical_ids =
getInputsInTargetDomain({loop_id}, tv->getLogicalDomain());
@@ -379,8 +413,14 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
// TODO(#4604): This is problematic for 2D sharding.
fill_communication_info(
CommunicationType::SendRecv, p_logical_id, c_logical_id);

if (c_logical_id == p2c_map.at(p_logical_id)) {
fill_communication_info(
CommunicationType::SendRecv, p_logical_id, c_logical_id);
} else {
fill_communication_info(
CommunicationType::AllToAll, nullptr, nullptr);
}
}
} else {
NVF_ERROR(e->isA<ReductionOp>() || e->isA<SqueezeOp>());
@@ -441,10 +481,12 @@ Layout getCommunicationLayout(
// For the following communication types, the sharded_id does not have to be
// outermost in allocation domain. Nonetheless, `tv` still needs to be
// contiguous and therefore .contiguous() at the beginning of this function.
// TODO(prmishra): Fix the layout for AllToAll.
if (type == CommunicationType::Reduce ||
type == CommunicationType::Allreduce ||
type == CommunicationType::Broadcast ||
type == CommunicationType::SendRecv) {
type == CommunicationType::SendRecv ||
type == CommunicationType::AllToAll) {
return layout;
}

@@ -568,6 +610,9 @@ std::vector<Expr*> convertSingleOpToCommunication(
case CommunicationType::Reduce:
lowerToReduce(input_tv, output_tv, op_type(e), backend, comms);
break;
case CommunicationType::AllToAll:
lowerToAllToAll(input_tv, output_tv, backend, comms);
break;
}

return comms;
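Note on the `getCommunicationInfo` change above: the new branch keeps SendRecv when the producer's sharded logical ID maps to the consumer's sharded logical ID (the same logical axis stays sharded across meshes) and falls back to AllToAll when the two tensors shard different logical axes. A minimal sketch of that rule, in plain Python with hypothetical axis indices rather than nvFuser IR:

```python
# Hypothetical sketch (not nvFuser API) of the classification rule added in
# getCommunicationInfo: resharding the same logical axis is a SendRecv between
# meshes; sharding one axis on the producer and a different axis on the
# consumer requires every device to exchange data with every other device,
# i.e. an AllToAll.
def classify_resharding(producer_sharded_axis: int, consumer_sharded_axis: int) -> str:
    if producer_sharded_axis == consumer_sharded_axis:
        return "SendRecv"
    return "AllToAll"

# [DIDx(d), n/d, m] -> [DIDx(d), n/d, m] on another mesh: same axis stays sharded.
assert classify_resharding(0, 0) == "SendRecv"
# [DIDx(d), n/d, m] -> [n, DIDx(d), m/d]: gather axis 0, scatter axis 1.
assert classify_resharding(0, 1) == "AllToAll"
```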
73 changes: 73 additions & 0 deletions csrc/multidevice/communication.cpp
@@ -54,6 +54,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
case CommunicationType::SendRecv:
os << "SendRecv";
break;
case CommunicationType::AllToAll:
os << "AllToAll";
break;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
}
@@ -152,6 +155,7 @@ bool hasRoot(CommunicationType type) {
case CommunicationType::Allgather:
case CommunicationType::Allreduce:
case CommunicationType::ReduceScatter:
case CommunicationType::AllToAll:
return false;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
@@ -169,6 +173,7 @@ bool isReduction(CommunicationType type) {
case CommunicationType::Scatter:
case CommunicationType::Broadcast:
case CommunicationType::SendRecv:
case CommunicationType::AllToAll:
return false;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
@@ -605,6 +610,71 @@ c10::intrusive_ptr<c10d::Work> postSendRecv(
/*tag=*/0);
}
}

c10::intrusive_ptr<c10d::Work> postAllToAll(
Communication* communication,
DeviceIdxType my_device_index,
c10d::Backend* backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
NVF_ERROR(
isTvContiguous(communication->in()),
"Input tensor is not contiguous: ",
communication->in(),
" contiguity: ",
communication->in()->domain()->getContiguityString());
NVF_ERROR(
isTvContiguous(communication->out()),
"Output tensor is not contiguous: ",
communication->out(),
" contiguity: ",
communication->out()->domain()->getContiguityString());

// input_tv = [DIDx(d), n/d, m, ...]
// output_tv = [n, DIDx(d), m/d, ...]
// `n`: gathered dimension
// `m`: scattered dimension
// For alltoall correctness, we split `m` and reorder as [DIDx(d), d, n/d,
// m/d, ...] such that alltoall_base splits across the `d` dimension.

int64_t d = communication->team_size();
auto input_sizes = input_tensor.sizes();

NVF_CHECK(
input_sizes.at(1) % d == 0,
"Scattered dimension must be divisible by the team size");

std::vector<int64_t> input_reshape_sizes(
input_sizes.begin(), input_sizes.end());
input_reshape_sizes.at(1) = d;
input_reshape_sizes.insert(
input_reshape_sizes.begin() + 2, input_sizes.at(1) / d);
auto reshaped_input = input_tensor.reshape(input_reshape_sizes);

std::vector<int64_t> permute_dims(input_reshape_sizes.size());
std::iota(permute_dims.begin(), permute_dims.end(), 0);
std::swap(permute_dims[0], permute_dims[1]);

auto reordered_input = reshaped_input.permute(permute_dims).contiguous();

auto flattened_input_tensor = viewAsCompact(reordered_input);
auto flattened_output_tensor = viewAsCompact(output_tensor);

// alltoall_base requires even splits of the input and output tensors.
auto input_splits = at::tensor_split(
flattened_input_tensor, communication->team_size(), /*dim=*/0);
auto output_splits = at::tensor_split(
flattened_output_tensor, communication->team_size(), /*dim=*/0);
assertBuffersHaveSameSize(input_splits, output_splits);

std::vector<int64_t> empty_split_sizes;
return backend->alltoall_base(
flattened_output_tensor,
flattened_input_tensor,
empty_split_sizes,
empty_split_sizes,
/*options=*/{});
}
} // namespace

c10::intrusive_ptr<c10d::Work> postSingleCommunication(
@@ -691,6 +761,9 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
backend,
input_tensor,
output_tensor);
case CommunicationType::AllToAll:
return postAllToAll(
communication, my_device_index, backend, input_tensor, output_tensor);
default:
NVF_THROW("Wrong communication type: ", communication->type());
return nullptr;
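The layout comment inside `postAllToAll` is easiest to verify with a single-process sketch. The snippet below (plain PyTorch, no process group; shapes are assumptions taken from the comment) mimics the reshape/permute that precedes `alltoall_base` and checks that concatenating the received chunks in rank order reproduces the expected output shard:

```python
# Single-process sketch of postAllToAll's data movement (assumed shapes, no NCCL).
# Per the comment in communication.cpp:
#   per-rank input  = [n/d, m]  (full [n, m] tensor sharded on the gathered dim n)
#   per-rank output = [n, m/d]  (full [n, m] tensor sharded on the scattered dim m)
import torch

d, n, m = 4, 8, 12
x = torch.arange(n * m).reshape(n, m)
local_in = list(x.chunk(d, dim=0))  # rank r holds rows [r*n/d, (r+1)*n/d)

send_chunks = []
for r in range(d):
    # Split the scattered dimension m into [d, m/d], move d outermost, and make
    # it contiguous, so an even split along dim 0 yields the d send buffers.
    reordered = local_in[r].reshape(n // d, d, m // d).permute(1, 0, 2).contiguous()
    send_chunks.append([reordered[j] for j in range(d)])  # each [n/d, m/d]

# Simulated exchange: rank j receives chunk j from every rank r, concatenated
# in rank order along the gathered dimension.
for j in range(d):
    received = torch.cat([send_chunks[r][j] for r in range(d)], dim=0)  # [n, m/d]
    assert torch.equal(received, x[:, j * (m // d) : (j + 1) * (m // d)])
```

The real implementation additionally flattens both buffers with `viewAsCompact` and passes empty split sizes, so `alltoall_base` performs the same even split along the leading dimension.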
3 changes: 2 additions & 1 deletion csrc/multidevice/communication.h
@@ -32,7 +32,8 @@ enum class CommunicationType {
Allreduce,
ReduceScatter,
Broadcast,
SendRecv
SendRecv,
AllToAll
};

std::ostream& operator<<(std::ostream& os, const CommunicationType& type);
73 changes: 73 additions & 0 deletions tests/python/multidevice/test_multidevice.py
@@ -447,3 +447,76 @@ def test_binary(multidevice_test):
torch.testing.assert_close(
z, multidevice_test.shard_tensor(x.float() + y.float(), 0, mesh)
)


@pytest.mark.mpi
def test_alltoall(multidevice_test):
d = multidevice_test.size
mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
k, m, n = 3, 40, 56

assert (
m % d == 0 and n % d == 0
), "Sharded dimensions must be divisible by the team size"

with FusionDefinition() as fd:
inp = fd.define_tensor((k, m, n), contiguity=True, dtype=DataType.Half)
# Ideally, we could have exposed and split the allocation domains of the
# alltoall input and output as follows:
# Input of shape [k, m, n] sharded on n
# Output of shape [k, m, n] sharded on m
#
# Here `d` is an outer split of `n` and `d*` is an outer split of `m`:
# input [k, m, DIDx(d), n/d]
# alltoall_input [d*, k, m/d*, DIDx(d), n/d]
# d* is reordered outermost. `alltoall_base` slices on the
# outermost dimension.
# alltoall_output [d, k, DIDx(d*), m/d*, n/d]
# d is gathered back. AllToAll places data from each device in
# memory in rank order.
# output [k, DIDx(d*), m/d*, n]
# This is the same as sharding the original input along the row
# dimension instead of the column dimension.
#
# However, tensors whose allocation domains contain non-adjacent
# splits fail stride validation. To avoid this, I permute the gathered
# and scattered dimensions to be outermost, in that order. `postAllToAll`
# will split the scattered dimension, reorder it as
# [DIDx(d), d, n/d, m/d, k], and make it contiguous. The output will be
# [n, DIDx(d*), m/d*, k]. This avoids the non-adjacent split in the
# output and avoids exposing the non-adjacent split of `m` in the input
# to the fusion definition. I use `permute` instead of setting the
# allocation domain to avoid another copy.
all2all_inp = fd.ops.permute(inp, dims=[2, 1, 0])
all2all_out = fd.ops.set(all2all_inp)
out = fd.ops.permute(all2all_out, dims=[2, 1, 0])
fd.add_output(all2all_inp)
fd.add_output(all2all_out)
fd.add_output(out)

inp.set_device_mesh(mesh)
inp.outer_split(2, d)
inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x)

all2all_inp.set_device_mesh(mesh)
all2all_inp.outer_split(0, d)
all2all_inp.axis(0).parallelize(
nvfuser.ParallelType.mesh_x
) # [DIDx(d), n/d, m, k]

all2all_out.set_device_mesh(mesh)
all2all_out.outer_split(1, d)
all2all_out.axis(1).parallelize(
nvfuser.ParallelType.mesh_x
) # [n, DIDx(d), m/d, k]
all2all_out.set_allocation_domain(all2all_out.get_loop_domain(), True)

out.set_device_mesh(mesh)
out.outer_split(1, d)
out.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
out.set_allocation_domain(out.get_loop_domain(), True)

in_tensor = torch.randn(k, m, n, dtype=torch.float16)
sharded = multidevice_test.shard_tensor(in_tensor, 2, mesh)
(all2all_inp, all2all_out, out) = fd.execute([sharded])
torch.testing.assert_close(out, multidevice_test.shard_tensor(in_tensor, 1, mesh))
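
As a reading aid, the final check amounts to a pure resharding of a `[k, m, n]` tensor from column- to row-sharding. A rough single-rank sketch of the reference semantics, assuming `shard_tensor(t, dim, mesh)` returns this rank's equal chunk of `t` along `dim` (the rank index below is hypothetical):

```python
# Rough reference-semantics sketch; r is a hypothetical rank, not MPI state.
import torch

d, r = 2, 0                      # hypothetical world size and rank
k, m, n = 3, 40, 56
t = torch.randn(k, m, n, dtype=torch.float16)

local_in = t.chunk(d, dim=2)[r]       # what the fusion receives: [k, m, n/d]
expected_out = t.chunk(d, dim=1)[r]   # what assert_close compares against: [k, m/d, n]
# The fusion only permutes, reshards via AllToAll, and permutes back, so its
# output must match this slice exactly (no arithmetic is performed).
```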