@@ -6,7 +6,7 @@
 import torch
 
 import nvfuser_direct as nvfuser
-from nvfuser_direct import DataType, FusionDefinition
+from nvfuser_direct import DataType, FusionDefinition, PythonProfiler
 
 
 # Avoid doing this when possible. This test started to exist before nvFuser
@@ -197,50 +197,52 @@ def _multidevice_schedule(fd: FusionDefinition):
 def test_linear_reduce_scatter(multidevice_direct_test):
     d = multidevice_direct_test.size
     mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
-    e = 768
+    b, s, e = 3, 5, 7
 
     def _definition(fd: FusionDefinition):
-        inp = fd.define_tensor([-1, -1, d * e])
-        weight = fd.define_tensor([e, d * e])
-        out = fd.ops.linear(inp, weight, None)
+        inp = fd.define_tensor([-1, d * s, d * e], dtype=DataType.BFloat16)
+        weight = fd.define_tensor([-1, d * e], dtype=DataType.BFloat16)
+        bias = fd.define_tensor([e], dtype=DataType.BFloat16)
+        out = fd.ops.linear(inp, weight, bias)
         fd.add_output(out)
 
     def _multidevice_schedule(fd: FusionDefinition):
-        inp, weight = fd.fusion.inputs()
+        inp, weight, bias = fd.fusion.inputs()
         (out,) = fd.fusion.outputs()
-        for t in [inp, weight, out]:
-            t.set_device_mesh(mesh)
-            t.outer_split(-1, d)
-            t.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)
+        bias.set_device_mesh(mesh)
+        for tv in [inp, weight, out]:
+            tv.set_device_mesh(mesh)
+            tv.split(-1, d, inner_split=False)
+            tv.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)
 
         # Scatter
         out.outer_split(1, d)
         out.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
 
     torch.cuda.set_device(multidevice_direct_test.local_rank)
 
-    # set b=1 as a temporary fix for the test to pass.
-    # TODO: set b>1 once reduce scatter is fixed.
-    b, s = 2, 1024
-    unsharded_inp = torch.randn(b, s, d * e)
-    unsharded_weight = torch.randn(e, d * e)
-
+    unsharded_inp = torch.randint(-2, 3, (b, d * s, d * e)).to(torch.bfloat16)
+    unsharded_weight = torch.randint(-2, 3, (e, d * e)).to(torch.bfloat16)
+    bias = torch.randint(-2, 3, (e,)).to(torch.bfloat16)
     inp = multidevice_direct_test.shard_tensor(unsharded_inp, -1, mesh)
     weight = multidevice_direct_test.shard_tensor(unsharded_weight, -1, mesh)
 
     with FusionDefinition() as fd:
         _definition(fd)
         _multidevice_schedule(fd)
 
-    (out,) = fd.execute([inp, weight])
+    with PythonProfiler() as prof:
+        (out,) = fd.execute([inp, weight, bias.cuda()])
 
-    unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, None)
-    # rtol is the same as the default for fp32. atol is slightly increased.
+    # Only one reduce scatter kernel should be scheduled.
+    assert len(
+        [kp for kp in prof.profile.kernel_profiles if kp.scheduler == "communication"]
+    ) == (1 if d > 1 else 0)
+
+    unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, bias)
     torch.testing.assert_close(
         out,
         multidevice_direct_test.shard_tensor(unsharded_out, 1, mesh),
-        rtol=1.3e-6,
-        atol=1e-3,
    )
 
 
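As a sanity check on the math this test encodes, the reduce-scatter linear can be reproduced with plain PyTorch on a single process. The sketch below is illustrative only: the shapes and the loop over `d` "ranks" are assumptions that mirror the test (`d` stands in for the device count), not part of nvFuser's API. Each rank owns a slice of the contraction dimension of `inp` and `weight`, its local matmul is a partial result, the partials are summed (the reduce), and the summed output is chunked along the sequence dimension (the scatter expressed by `out.outer_split(1, d)` above).

import torch

# Illustrative single-process sketch of a reduce-scatter linear. Shapes mirror
# the test; d plays the role of the device count. Small integer-valued tensors
# keep the arithmetic exact, as in the test above.
d, b, s, e = 4, 3, 5, 7
inp = torch.randint(-2, 3, (b, d * s, d * e)).float()
weight = torch.randint(-2, 3, (e, d * e)).float()
bias = torch.randint(-2, 3, (e,)).float()

# Each "rank" r holds a slice of the contraction dimension and computes a
# partial linear; summing the partials is the reduce half of reduce-scatter.
partials = [
    inp[..., r * e : (r + 1) * e] @ weight[:, r * e : (r + 1) * e].T
    for r in range(d)
]
reduced = sum(partials) + bias

# The scatter half: rank r keeps only its chunk of the sequence dimension,
# matching the out.outer_split(1, d) schedule above.
scattered = reduced.chunk(d, dim=1)

expected = torch.nn.functional.linear(inp, weight, bias).chunk(d, dim=1)
for r in range(d):
    torch.testing.assert_close(scattered[r], expected[r])

Using small integer-valued inputs, as the updated test does via `torch.randint(-2, 3, ...)` cast to bfloat16, keeps the reference comparison exact, which is presumably why the explicit `rtol`/`atol` arguments could be dropped.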
|
|