
Commit e224b26

Transformer sequence parallel forward (#5560)
Wall-clock time measured on `8 80GB H100 nodes`:

| TE | nvFuser |
| --- | --- |
| 2.5 ms | 2.1 ms |
1 parent c49022d commit e224b26

3 files changed: +52 −38 lines

tests/python/multidevice/benchmark_utils.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -3,6 +3,14 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import torch
+from enum import auto, Enum
+
+
+class Parallelism(Enum):
+    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#tensor-parallelism
+    TENSOR_PARALLEL = auto()
+    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#sequence-parallelism
+    SEQUENCE_PARALLEL = auto()
 
 
 def get_benchmark_fn(func, /, profile: bool):
```
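The `Parallelism` enum moves into `benchmark_utils.py` so both test modules can share it. As a rough sketch of the distinction it names (not code from this commit; sizes are hypothetical): sequence parallelism shards the `[b, s, e]` activations along the sequence dimension outside the attention/MLP regions, while tensor parallelism keeps them replicated.

```python
from enum import auto, Enum


class Parallelism(Enum):
    TENSOR_PARALLEL = auto()
    SEQUENCE_PARALLEL = auto()


def local_activation_shape(b: int, s: int, e: int, d: int, parallelism: Parallelism) -> list[int]:
    # Sequence parallelism shards the [b, s, e] activations along the
    # sequence dimension across d devices; tensor parallelism replicates them.
    if parallelism == Parallelism.SEQUENCE_PARALLEL:
        return [b, s // d, e]
    return [b, s, e]


# Hypothetical sizes: batch 1, sequence 2048, hidden 768, 8 devices.
print(local_activation_shape(1, 2048, 768, 8, Parallelism.TENSOR_PARALLEL))  # [1, 2048, 768]
print(local_activation_shape(1, 2048, 768, 8, Parallelism.SEQUENCE_PARALLEL))  # [1, 256, 768]
```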

tests/python/multidevice/test_transformer.py

Lines changed: 43 additions & 30 deletions

```diff
@@ -12,7 +12,7 @@
     create_sdpa_rng_tensors,
     is_pre_ampere,
 )
-from benchmark_utils import get_benchmark_fns
+from benchmark_utils import get_benchmark_fns, Parallelism
 
 
 @pytest.mark.mpi
```
```diff
@@ -349,24 +349,12 @@ def transformer_forward_definition(
     fd.add_output(out)
 
 
-def transformer_forward_multidevice_schedule(fd: FusionDefinition, num_devices: int):
+def transformer_forward_multidevice_schedule(
+    fd: FusionDefinition, num_devices: int, parallelism: Parallelism
+):
     mesh = nvfuser.multidevice.DeviceMesh(range(num_devices))
     inputs = fd.fusion.inputs()
-    inp = inputs[0]
-    layernorm0_weight = inputs[1]
-    layernorm0_bias = inputs[2]
-    mha_linear0_weight = inputs[3]
-    mha_linear0_bias = inputs[4]
-    mha_linear1_weight = inputs[5]
-    mha_linear1_bias = inputs[6]
-    layernorm1_weight = inputs[7]
-    layernorm1_bias = inputs[8]
-    mlp_linear0_weight = inputs[9]
-    mlp_linear0_bias = inputs[10]
-    mlp_linear1_weight = inputs[11]
-    mlp_linear1_bias = inputs[12]
-
-    for tv in [
+    (
         inp,
         layernorm0_weight,
         layernorm0_bias,
```

```diff
@@ -380,9 +368,15 @@ def transformer_forward_multidevice_schedule(fd: FusionDefinition, num_devices:
         mlp_linear0_bias,
         mlp_linear1_weight,
         mlp_linear1_bias,
-    ]:
+    ) = inputs
+
+    for tv in inputs:
         tv.set_device_mesh(mesh)
 
+    if parallelism == Parallelism.SEQUENCE_PARALLEL:
+        inp.outer_split(1, num_devices)
+        inp.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
+
     for tv in [
         mha_linear0_weight,
         mha_linear0_bias,
```
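The two new scheduling lines are the core of the change: `inp.outer_split(1, num_devices)` splits axis 1 (the sequence dimension of the `[b, s, e]` input) into `num_devices` outer chunks, and parallelizing the new axis on `mesh_x` assigns one chunk per device. These calls only annotate the fusion; as a plain-PyTorch illustration of the data layout they describe (hypothetical sizes, not this test's code):

```python
import torch

b, s, e, d = 2, 16, 8, 4  # hypothetical sizes; s must divide evenly by d
inp = torch.randn(b, s, e)

# After a sequence-parallel split, each of the d devices owns one
# contiguous [b, s // d, e] slice of the input.
shards = inp.tensor_split(d, dim=1)
assert all(shard.shape == (b, s // d, e) for shard in shards)
```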
```diff
@@ -413,8 +407,15 @@ def _assert_shape_dtype(
     is_pre_ampere(),
     reason="Flash Attention is only supported on Ampere and newer devices.",
 )
+@pytest.mark.parametrize(
+    "parallelism",
+    [Parallelism.TENSOR_PARALLEL, Parallelism.SEQUENCE_PARALLEL],
+    ids=["tp", "sp"],
+)
 @pytest.mark.mpi
-def test_transformer_forward(multidevice_direct_test, benchmark):
+def test_transformer_forward(
+    multidevice_direct_test, benchmark, parallelism: Parallelism
+):
     d = multidevice_direct_test.size
     mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
 
```
```diff
@@ -426,7 +427,8 @@ def test_transformer_forward(multidevice_direct_test, benchmark):
 
     if h % d != 0:
         pytest.skip(
-            f"We only support even DID split, so the number of heads ({h}) has to be divisible by the number of GPUs ({d})."
+            f"We only support even DID split, so the number of heads ({h}) has \
+to be divisible by the number of GPUs ({d})."
         )
 
     assert e * 4 % d == 0, (
```
```diff
@@ -435,10 +437,17 @@ def test_transformer_forward(multidevice_direct_test, benchmark):
         "error. So I use `assert` instead of `pytest.skip`."
     )
 
+    if parallelism == Parallelism.SEQUENCE_PARALLEL and s % d != 0:
+        pytest.skip(
+            f"Sequence length {s} must be divisible by the number \
+of devices {d} for sequence parallelism."
+        )
+
     torch.cuda.set_device(multidevice_direct_test.local_rank)
 
     # To reduce memory footprint, create unsharded data on CPU and copy only
     # the needed slice to GPU.
+    inp = torch.testing.make_tensor(b, s, e, dtype=torch.bfloat16, device="cpu")
     mha_linear0_weight = torch.testing.make_tensor(
         e * 3, e, dtype=torch.bfloat16, device="cpu"
     )
```
```diff
@@ -462,7 +471,9 @@ def test_transformer_forward(multidevice_direct_test, benchmark):
     # arguments. They are passed in in the same order as the `define_scalar`s
     # and `define_tensor`s.
     ins = [
-        torch.testing.make_tensor((b, s, e), dtype=torch.bfloat16, device="cuda"),
+        inp.cuda()
+        if parallelism == Parallelism.TENSOR_PARALLEL
+        else multidevice_direct_test.shard_tensor(inp, 1, mesh),
         torch.testing.make_tensor((e,), dtype=torch.bfloat16, device="cuda"),
         torch.testing.make_tensor((e,), dtype=torch.bfloat16, device="cuda"),
         multidevice_direct_test.shard_tensor(mha_linear0_weight, 0, mesh),
```
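Under tensor parallelism the whole input moves to the GPU; under sequence parallelism each rank passes only its sequence shard. The fixture's `shard_tensor` implementation isn't part of this diff; a plausible stand-in (hypothetical, for illustration only) would slice along the given dimension and copy the local piece to the device:

```python
import torch


def shard_tensor_sketch(t: torch.Tensor, dim: int, num_devices: int, rank: int) -> torch.Tensor:
    # Hypothetical stand-in for multidevice_direct_test.shard_tensor:
    # keep only this rank's slice along `dim` and move it to the GPU.
    return t.tensor_split(num_devices, dim=dim)[rank].cuda()
```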
```diff
@@ -479,7 +490,7 @@ def test_transformer_forward(multidevice_direct_test, benchmark):
 
     with FusionDefinition() as fd:
         transformer_forward_definition(fd, b, s, h, e)
-        transformer_forward_multidevice_schedule(fd, d)
+        transformer_forward_multidevice_schedule(fd, d, parallelism)
 
     warmup_fn, benchmark_fn = get_benchmark_fns(lambda: fd.execute(ins))
 
```
```diff
@@ -501,21 +512,23 @@ def test_transformer_forward(multidevice_direct_test, benchmark):
         out,
     ) = warmup_fn()
 
-    _assert_shape_dtype(layernorm0_mean, [b, s], torch.float32)
-    _assert_shape_dtype(layernorm0_rstd, [b, s, 1], torch.float32)
+    s_local = s // d if parallelism == Parallelism.SEQUENCE_PARALLEL else s
+
+    _assert_shape_dtype(layernorm0_mean, [b, s_local], torch.float32)
+    _assert_shape_dtype(layernorm0_rstd, [b, s_local, 1], torch.float32)
     _assert_shape_dtype(mha_linear0_out, [b, s, e * 3 // d], torch.bfloat16)
     _assert_shape_dtype(sdpa_out, [b, h // d, s, e // h], torch.bfloat16)
     _assert_shape_dtype(sdpa_logsum_exp, [b, h // d, s], torch.float32)
     ref_philox_seed, ref_philox_offset = create_sdpa_rng_tensors()
     _assert_shape_dtype(sdpa_seed, ref_philox_seed.shape, ref_philox_seed.dtype)
     _assert_shape_dtype(sdpa_offset, ref_philox_offset.shape, ref_philox_offset.dtype)
-    _assert_shape_dtype(mha_linear1_out, [b, s, e], torch.bfloat16)
-    _assert_shape_dtype(mha_dropout_mask, [b, s, e], torch.bool)
-    _assert_shape_dtype(layernorm1_mean, [b, s], torch.float32)
-    _assert_shape_dtype(layernorm1_rstd, [b, s, 1], torch.float32)
+    _assert_shape_dtype(mha_linear1_out, [b, s_local, e], torch.bfloat16)
+    _assert_shape_dtype(mha_dropout_mask, [b, s_local, e], torch.bool)
+    _assert_shape_dtype(layernorm1_mean, [b, s_local], torch.float32)
+    _assert_shape_dtype(layernorm1_rstd, [b, s_local, 1], torch.float32)
     _assert_shape_dtype(mlp_linear0_out, [b, s, e * 4 // d], torch.bfloat16)
-    _assert_shape_dtype(mlp_dropout_mask, [b, s, e], torch.bool)
-    _assert_shape_dtype(out, [b, s, e], torch.bfloat16)
+    _assert_shape_dtype(mlp_dropout_mask, [b, s_local, e], torch.bool)
+    _assert_shape_dtype(out, [b, s_local, e], torch.bfloat16)
 
     # Benchmark and profile. The profile can be collected and displayed using
     # `nsys`. See instructions in test_transformer_engine.py.
```
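The expected shapes follow from what each mode shards: tensors produced inside the attention/MLP regions keep the head/feature split (the `// d` factors on `h`, `e * 3`, and `e * 4`) in both modes, while tensors outside those regions gain the sequence split only under sequence parallelism. A worked example with hypothetical sizes:

```python
b, s, h, e, d = 1, 2048, 16, 768, 8  # hypothetical test sizes
s_local = s // d  # 256 under sequence parallelism; equals s under tensor parallelism

# Sharded along heads/features in both modes:
mha_linear0_out_shape = [b, s, e * 3 // d]  # [1, 2048, 288]
sdpa_out_shape = [b, h // d, s, e // h]  # [1, 2, 2048, 48]

# Sharded along the sequence only under sequence parallelism:
out_shape = [b, s_local, e]  # [1, 256, 768]
```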

tests/python/multidevice/test_transformer_engine.py

Lines changed: 1 addition & 8 deletions

```diff
@@ -6,7 +6,7 @@
 import torch
 import torch.distributed as dist
 import transformer_engine.pytorch as te
-from benchmark_utils import get_benchmark_fns
+from benchmark_utils import get_benchmark_fns, Parallelism
 from enum import auto, Enum
 
 
```

```diff
@@ -17,13 +17,6 @@ class ComputeType(Enum):
     BACKWARD = auto()
 
 
-class Parallelism(Enum):
-    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#tensor-parallelism
-    TENSOR_PARALLEL = auto()
-    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#sequence-parallelism
-    SEQUENCE_PARALLEL = auto()
-
-
 # This benchmark is instrumented with cudaProfilerStart/Stop. Therefore, one
 # can collect stats of the first few non-warmup benchmark iterations using
 # ```bash
```
