Fix decomposeLinearWithBias to shard all created tensorviews
#5563
Changes from 1 commit
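For orientation, a minimal sketch of the computation a linear-with-bias decomposition performs. This is an illustration only, under the assumption that the decomposition splits `linear` into a matmul followed by a bias add; the intermediate matmul result stands in for the tensorviews that, per the PR title, must all be sharded onto the device mesh. It is not an excerpt of `decomposeLinearWithBias` itself.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only; the test below uses [b, s, d*e] inputs with e = 768.
inp = torch.randn(2, 8, 16)   # [b, s, d*e]
weight = torch.randn(4, 16)   # [e, d*e]
bias = torch.randn(4)         # [e]

# Assumed decomposition: matmul first, then a bias add. The matmul output is the
# kind of intermediate tensorview this PR makes sure also carries the sharding.
matmul_out = inp @ weight.t()
out = matmul_out + bias

torch.testing.assert_close(out, F.linear(inp, weight, bias))
```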
```diff
@@ -84,7 +84,7 @@ KernelArgumentHolder CommunicationExecutor::run(
         group_id_);
     SegmentProfiler& sprof = FusionProfiler::segment(group_id_);
     sprof.inputBytesAccessed(computeBytes(args));
-    sprof.scheduler(toString(SchedulerType::ExprEval));
+    sprof.scheduler(toString(SchedulerType::Communication));
     sprof.startKernel();
   }
   NVF_ERROR(host_ir_container_, "Need to compile before you can run.");
```

Author (on the removed `ExprEval` line): Caused the wrong scheduler name in profiler output.
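The label set here is what the new Python assertion below keys on. A minimal sketch of how it surfaces through `PythonProfiler` (API names taken from the test in this PR; `fd` and `inputs` are placeholders for a compiled FusionDefinition and its inputs):

```python
# Sketch only: after this fix, a communication segment reports the
# "communication" scheduler string; before, it carried the label derived from
# SchedulerType::ExprEval, so a filter like this would miss it.
with PythonProfiler() as prof:
    (out,) = fd.execute(inputs)  # fd/inputs: placeholders, see the test below

communication_kernels = [
    kp for kp in prof.profile.kernel_profiles if kp.scheduler == "communication"
]
print(len(communication_kernels))
```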
```diff
@@ -6,7 +6,7 @@
 import torch
 
 import nvfuser_direct as nvfuser
-from nvfuser_direct import DataType, FusionDefinition
+from nvfuser_direct import DataType, FusionDefinition, PythonProfiler
 
 
 # Avoid doing this when possible. This test started to exist before nvFuser
```
```diff
@@ -200,47 +200,57 @@ def test_linear_reduce_scatter(multidevice_direct_test):
     e = 768
 
     def _definition(fd: FusionDefinition):
-        inp = fd.define_tensor([-1, -1, d * e])
-        weight = fd.define_tensor([e, d * e])
-        out = fd.ops.linear(inp, weight, None)
+        inp = fd.define_tensor([-1, -1, d * e], dtype=DataType.BFloat16)
+        weight = fd.define_tensor([-1, d * e], dtype=DataType.BFloat16)
+        bias = fd.define_tensor([e], dtype=DataType.BFloat16)
+        out = fd.ops.linear(inp, weight, bias)
         fd.add_output(out)
 
     def _multidevice_schedule(fd: FusionDefinition):
-        inp, weight = fd.fusion.inputs()
+        inp, weight, bias = fd.fusion.inputs()
         (out,) = fd.fusion.outputs()
-        for t in [inp, weight, out]:
-            t.set_device_mesh(mesh)
-            t.split(-1, d, inner_split=False)
-            t.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)
+        bias.set_device_mesh(mesh)
+        for tv in [inp, weight, out]:
+            tv.set_device_mesh(mesh)
+            tv.split(-1, d, inner_split=False)
+            tv.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)
 
         # Scatter
         out.split(1, d, inner_split=False)
         out.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
 
     torch.cuda.set_device(multidevice_direct_test.local_rank)
 
-    # set b=1 as a temporary fix for the test to pass.
-    # TODO: set b>1 once reduce scatter is fixed.
-    b, s = 2, 1024
-    unsharded_inp = torch.randn(b, s, d * e)
-    unsharded_weight = torch.randn(e, d * e)
+    b, s = 2, 8
+    unsharded_inp = torch.randint(-2, 3, (b, s, d * e)).to(torch.bfloat16)
+    unsharded_weight = torch.randint(-2, 3, (e, d * e)).to(torch.bfloat16)
+    bias = torch.randint(-2, 3, (e,)).to(torch.bfloat16)
     inp = multidevice_direct_test.shard_tensor(unsharded_inp, -1, mesh)
     weight = multidevice_direct_test.shard_tensor(unsharded_weight, -1, mesh)
 
     with FusionDefinition() as fd:
         _definition(fd)
         _multidevice_schedule(fd)
 
-    (out,) = fd.execute([inp, weight])
+    with PythonProfiler() as prof:
+        (out,) = fd.execute([inp, weight, bias.cuda()])
```
Collaborator (on the `PythonProfiler` block): Does this synchronize? Could we miss kernels?

Author: Is this what you are referring to?

Collaborator: There's a difference between `cudaStreamSynchronize` and `cudaDeviceSynchronize`, though: the former makes the host wait only for the work queued on that one stream, while the latter waits for all work on the device.

Author: You're right. I assumed FusionProfiler/PythonProfiler synchronizes at start but not on stop, so I will add an explicit call here. Note to self: check whether FusionProfiler should synchronize before reading data.
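A sketch of the explicit synchronization the author mentions adding. `torch.cuda.synchronize()` is my assumption for the call (the excerpt does not show which one was ultimately used); it blocks the host until all work queued on the current device, across all streams, has completed, so the reduce-scatter is recorded before the profile is read.

```python
with PythonProfiler() as prof:
    (out,) = fd.execute([inp, weight, bias.cuda()])
    # Assumed explicit sync: wait for all streams on the device, not just the
    # default stream, so asynchronously launched kernels (e.g. the reduce
    # scatter) are captured before the profiler's data is read on exit.
    torch.cuda.synchronize()
```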
```diff
-    unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, None)
-    # rtol is the same as the default for fp32. atol is slightly increased.
+    # Only one reduce scatter kernel should be scheduled.
+    assert (
+        len(
+            [
+                kp
+                for kp in prof.profile.kernel_profiles
+                if kp.scheduler == "communication"
+            ]
+        )
+        == 1
+    )
+
+    unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, bias)
     torch.testing.assert_close(
         out,
         multidevice_direct_test.shard_tensor(unsharded_out, 1, mesh),
         rtol=1.3e-6,
         atol=1e-3,
     )
```
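For completeness, a shape-level sketch of the reduce-scatter pattern the schedule and the final `shard_tensor(unsharded_out, 1, mesh)` check express: inputs sharded on the hidden (`d * e`) axis, per-rank partial outputs summed, and the sum scattered along the sequence axis. Sizes here are tiny and illustrative, not the test's.

```python
import torch

d, b, s, e = 2, 2, 8, 4  # illustrative only; the test uses e = 768 and the mesh size for d

# Each rank holds a slice of the hidden dimension and computes a partial [b, s, e] output.
partials = [torch.randn(b, s, e) for _ in range(d)]

# Reduce-scatter = sum the partials, then scatter the sum along dim 1 (the sequence
# axis), so every rank keeps a [b, s // d, e] shard, matching
# shard_tensor(unsharded_out, 1, mesh) in the test.
full = sum(partials)
local_shards = full.chunk(d, dim=1)
assert local_shards[0].shape == (b, s // d, e)
```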