Commit 5d8efce

Fix decomposeLinearWithBias to shard all created tensorviews (#5563)
Some of the created tensorviews were not sharded consistently, which led to more communication than needed.
1 parent: fbf3dfd
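To make the fix concrete, here is a rough plain-PyTorch sketch (not nvFuser code; the device count and shapes are made up) of what decomposeLinearWithBias produces for a row-parallel linear: a partial matmul per device, a reduction across devices, and a bias add on the reduced result. Every tensorview created for the bias-add path has to carry the same sharding as the reduced output; otherwise the add runs on unsharded data and extra communication is scheduled.

import torch

d, b, s, e = 4, 3, 5, 7  # devices, batch, sequence, output features (made up)
inp = torch.randint(-2, 3, (b, s, d * e)).float()  # sharded on the last dim
weight = torch.randint(-2, 3, (e, d * e)).float()  # sharded on the last dim
bias = torch.randint(-2, 3, (e,)).float()  # replicated

# Each "device" contracts only its own slice of the inner dimension.
partials = [
    inp[..., r * e : (r + 1) * e] @ weight[:, r * e : (r + 1) * e].t()
    for r in range(d)
]

# `without_bias`: the cross-device reduction (a reduce-scatter in the real run).
without_bias = torch.stack(partials).sum(0)

# `broadcasted_bias` + add + cast: these intermediates must stay sharded like
# `without_bias`, or the bias add forces the partial results to be gathered.
out = without_bias + bias

torch.testing.assert_close(out, torch.nn.functional.linear(inp, weight, bias))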

4 files changed, +52 -26 lines changed

csrc/multidevice/propagation.cpp

Lines changed: 6 additions & 2 deletions

@@ -82,9 +82,13 @@ std::unordered_map<IterDomain*, IterDomain*> getRef2TargetMap(
     const TensorView* target,
     PropagateDirection direction) {
   if (direction == PropagateDirection::kForward) {
-    return PairwiseLogicalDomainMap(ref, target).mapProducerToConsumer();
+    return PairwiseLogicalDomainMap(ref, target)
+        .mapBroadcast(false)
+        .mapProducerToConsumer();
   }
-  return PairwiseLogicalDomainMap(target, ref).mapConsumerToProducer();
+  return PairwiseLogicalDomainMap(target, ref)
+      .mapBroadcast(false)
+      .mapConsumerToProducer();
 }

 // Propagates the given device/stream ids from ref to target.

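As a minimal sketch of what mapBroadcast(false) changes during sharding propagation, in plain Python rather than nvFuser IR (the Axis/Tensor classes and axis names below are invented for the example): axes that are broadcast on either side are left unmapped, so a size-1 broadcast dimension never inherits the reference tensor's device split.

from dataclasses import dataclass


@dataclass
class Axis:
    name: str
    size: int
    is_broadcast: bool = False
    parallel: str | None = None  # e.g. "mesh_x"


@dataclass
class Tensor:
    axes: list[Axis]


def propagate_parallel(ref: Tensor, target: Tensor) -> None:
    # Pair axes by name; a broadcast axis on either side stays unmapped,
    # mirroring PairwiseLogicalDomainMap(...).mapBroadcast(false).
    target_by_name = {a.name: a for a in target.axes}
    for ref_axis in ref.axes:
        target_axis = target_by_name.get(ref_axis.name)
        if target_axis is None or ref_axis.is_broadcast or target_axis.is_broadcast:
            continue
        target_axis.parallel = ref_axis.parallel


# The reduce-scattered output is sharded on its sequence axis; the broadcast
# bias has that axis only as a size-1 broadcast, so it must not inherit the split.
out = Tensor([Axis("s", 5, parallel="mesh_x"), Axis("e", 7)])
broadcasted_bias = Tensor([Axis("s", 1, is_broadcast=True), Axis("e", 7)])
propagate_parallel(ref=out, target=broadcasted_bias)
assert [a.parallel for a in broadcasted_bias.axes] == [None, None]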
csrc/preseg_passes/decompose_reshardings.cpp

Lines changed: 22 additions & 2 deletions

@@ -314,7 +314,6 @@ void decomposeRowParallelLinearWithBias(Fusion* fusion) {
     }

     auto* without_bias = linear(linear_op->inA(), linear_op->inB());
-    TransformReplay::selfReplay(out->domain(), without_bias->domain());

     TensorView* broadcasted_bias = [&]() {
       const int64_t rank_after_broadcast = std::ssize(
@@ -330,8 +329,29 @@ void decomposeRowParallelLinearWithBias(Fusion* fusion) {

     TensorView* new_out =
         maybeCastOp(out->dtype(), add(without_bias, broadcasted_bias));
-    TransformReplay::selfReplay(out->domain(), new_out->domain());
+
     ir_utils::replaceValInAllExprInputsAndFusionOutputs(out, new_out);
+
+    // Shard without_bias to match new_out so that reduction ID is properly
+    // sharded.
+    TransformReplay::selfReplay(out->domain(), without_bias->domain());
+    TransformReplay::selfReplay(out->domain(), new_out->domain());
+    // Backpropagate shardings to consistently shard all intermediate
+    // expressions. Forward propagating may miss sharding tensorviews
+    // on the path between `bias` and `new_out`.
+    for (Expr* expr : StmtSort::getExprsBetween(
+                          {without_bias, broadcasted_bias}, {new_out}) |
+             std::views::reverse) {
+      for (auto* output : ir_utils::filterByType<TensorView>(expr->outputs())) {
+        for (auto* input : ir_utils::filterByType<TensorView>(expr->inputs())) {
+          shardLoopLike(
+              /*ref=*/output,
+              /*target=*/input,
+              deviceAndStreamParallelTypes(),
+              PropagateDirection::kBackward);
+        }
+      }
+    }
   }
 }

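The new loop visits the expressions between without_bias/broadcasted_bias and new_out in reverse topological order and shards each expression's inputs like its output; forward propagation alone can leave the add and cast intermediates unsharded. A toy version of that back-propagation in plain Python (the expression list and sharding strings are invented; this is not the nvFuser API):

# Expressions in topological order: (op, inputs, outputs).
exprs = [
    ("add", ["without_bias", "broadcasted_bias"], ["add_out"]),
    ("cast", ["add_out"], ["new_out"]),
]

# Only the final output carries the desired sharding to begin with.
sharding = {"new_out": "mesh_x on the scattered axis"}

# Walk the expressions in reverse, as std::views::reverse does above, so every
# input inherits the sharding of the output that consumes it.
for _op, inputs, outputs in reversed(exprs):
    for out in outputs:
        if out not in sharding:
            continue
        for inp in inputs:
            sharding.setdefault(inp, sharding[out])

assert all(
    sharding[tv] == "mesh_x on the scattered axis"
    for tv in ["without_bias", "broadcasted_bias", "add_out", "new_out"]
)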
csrc/runtime/communication_executor.cpp

Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,7 @@ KernelArgumentHolder CommunicationExecutor::run(
         group_id_);
     SegmentProfiler& sprof = FusionProfiler::segment(group_id_);
     sprof.inputBytesAccessed(computeBytes(args));
-    sprof.scheduler(toString(SchedulerType::ExprEval));
+    sprof.scheduler(toString(SchedulerType::Communication));
     sprof.startKernel();
   }
   NVF_ERROR(host_ir_container_, "Need to compile before you can run.");

tests/python/multidevice/test_matmul.py

Lines changed: 23 additions & 21 deletions

@@ -6,7 +6,7 @@
 import torch

 import nvfuser_direct as nvfuser
-from nvfuser_direct import DataType, FusionDefinition
+from nvfuser_direct import DataType, FusionDefinition, PythonProfiler


 # Avoid doing this when possible. This test started to exist before nvFuser
@@ -197,50 +197,52 @@ def _multidevice_schedule(fd: FusionDefinition):
 def test_linear_reduce_scatter(multidevice_direct_test):
     d = multidevice_direct_test.size
     mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
-    e = 768
+    b, s, e = 3, 5, 7

     def _definition(fd: FusionDefinition):
-        inp = fd.define_tensor([-1, -1, d * e])
-        weight = fd.define_tensor([e, d * e])
-        out = fd.ops.linear(inp, weight, None)
+        inp = fd.define_tensor([-1, d * s, d * e], dtype=DataType.BFloat16)
+        weight = fd.define_tensor([-1, d * e], dtype=DataType.BFloat16)
+        bias = fd.define_tensor([e], dtype=DataType.BFloat16)
+        out = fd.ops.linear(inp, weight, bias)
         fd.add_output(out)

     def _multidevice_schedule(fd: FusionDefinition):
-        inp, weight = fd.fusion.inputs()
+        inp, weight, bias = fd.fusion.inputs()
         (out,) = fd.fusion.outputs()
-        for t in [inp, weight, out]:
-            t.set_device_mesh(mesh)
-            t.outer_split(-1, d)
-            t.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)
+        bias.set_device_mesh(mesh)
+        for tv in [inp, weight, out]:
+            tv.set_device_mesh(mesh)
+            tv.split(-1, d, inner_split=False)
+            tv.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)

         # Scatter
         out.outer_split(1, d)
         out.axis(1).parallelize(nvfuser.ParallelType.mesh_x)

     torch.cuda.set_device(multidevice_direct_test.local_rank)

-    # set b=1 as a temporary fix for the test to pass.
-    # TODO: set b>1 once reduce scatter is fixed.
-    b, s = 2, 1024
-    unsharded_inp = torch.randn(b, s, d * e)
-    unsharded_weight = torch.randn(e, d * e)
-
+    unsharded_inp = torch.randint(-2, 3, (b, d * s, d * e)).to(torch.bfloat16)
+    unsharded_weight = torch.randint(-2, 3, (e, d * e)).to(torch.bfloat16)
+    bias = torch.randint(-2, 3, (e,)).to(torch.bfloat16)
     inp = multidevice_direct_test.shard_tensor(unsharded_inp, -1, mesh)
     weight = multidevice_direct_test.shard_tensor(unsharded_weight, -1, mesh)

     with FusionDefinition() as fd:
         _definition(fd)
         _multidevice_schedule(fd)

-    (out,) = fd.execute([inp, weight])
+    with PythonProfiler() as prof:
+        (out,) = fd.execute([inp, weight, bias.cuda()])

-    unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, None)
-    # rtol is the same as the default for fp32. atol is slightly increased.
+    # Only one reduce scatter kernel should be scheduled.
+    assert len(
+        [kp for kp in prof.profile.kernel_profiles if kp.scheduler == "communication"]
+    ) == (1 if d > 1 else 0)
+
+    unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, bias)
     torch.testing.assert_close(
         out,
         multidevice_direct_test.shard_tensor(unsharded_out, 1, mesh),
-        rtol=1.3e-6,
-        atol=1e-3,
     )