     create_sdpa_rng_tensors,
     is_pre_ampere,
 )
-from benchmark_utils import get_benchmark_fns
+from benchmark_utils import get_benchmark_fns, Parallelism


 @pytest.mark.mpi
@@ -349,24 +349,12 @@ def transformer_forward_definition(
     fd.add_output(out)


-def transformer_forward_multidevice_schedule(fd: FusionDefinition, num_devices: int):
+def transformer_forward_multidevice_schedule(
+    fd: FusionDefinition, num_devices: int, parallelism: Parallelism
+):
     mesh = nvfuser.multidevice.DeviceMesh(range(num_devices))
     inputs = fd.fusion.inputs()
-    inp = inputs[0]
-    layernorm0_weight = inputs[1]
-    layernorm0_bias = inputs[2]
-    mha_linear0_weight = inputs[3]
-    mha_linear0_bias = inputs[4]
-    mha_linear1_weight = inputs[5]
-    mha_linear1_bias = inputs[6]
-    layernorm1_weight = inputs[7]
-    layernorm1_bias = inputs[8]
-    mlp_linear0_weight = inputs[9]
-    mlp_linear0_bias = inputs[10]
-    mlp_linear1_weight = inputs[11]
-    mlp_linear1_bias = inputs[12]
-
-    for tv in [
+    (
         inp,
         layernorm0_weight,
         layernorm0_bias,
@@ -380,9 +368,15 @@ def transformer_forward_multidevice_schedule(fd: FusionDefinition, num_devices:
         mlp_linear0_bias,
         mlp_linear1_weight,
         mlp_linear1_bias,
-    ]:
+    ) = inputs
+
+    for tv in inputs:
         tv.set_device_mesh(mesh)

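+    # Sequence parallelism: DID-split the input's sequence dimension (axis 1)
+    # across the device mesh so each rank holds s / num_devices tokens.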
+    if parallelism == Parallelism.SEQUENCE_PARALLEL:
+        inp.outer_split(1, num_devices)
+        inp.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
+
     for tv in [
         mha_linear0_weight,
         mha_linear0_bias,
@@ -413,8 +407,15 @@ def _assert_shape_dtype(
     is_pre_ampere(),
     reason="Flash Attention is only supported on Ampere and newer devices.",
 )
+@pytest.mark.parametrize(
+    "parallelism",
+    [Parallelism.TENSOR_PARALLEL, Parallelism.SEQUENCE_PARALLEL],
+    ids=["tp", "sp"],
+)
 @pytest.mark.mpi
-def test_transformer_forward(multidevice_direct_test, benchmark):
+def test_transformer_forward(
+    multidevice_direct_test, benchmark, parallelism: Parallelism
+):
     d = multidevice_direct_test.size
     mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))

@@ -426,7 +427,8 @@ def test_transformer_forward(multidevice_direct_test, benchmark):

     if h % d != 0:
         pytest.skip(
-            f"We only support even DID split, so the number of heads ({h}) has to be divisible by the number of GPUs ({d})."
+            f"We only support even DID split, so the number of heads ({h}) has "
+            f"to be divisible by the number of GPUs ({d})."
         )

     assert e * 4 % d == 0, (
@@ -435,10 +437,17 @@ def test_transformer_forward(multidevice_direct_test, benchmark):
435437 "error. So I use `assert` instead of `pytest.skip`."
436438 )
437439
+    if parallelism == Parallelism.SEQUENCE_PARALLEL and s % d != 0:
+        pytest.skip(
+            f"Sequence length {s} must be divisible by the number "
+            f"of devices {d} for sequence parallelism."
+        )
+
     torch.cuda.set_device(multidevice_direct_test.local_rank)

     # To reduce memory footprint, create unsharded data on CPU and copy only
     # the needed slice to GPU.
+    inp = torch.testing.make_tensor(b, s, e, dtype=torch.bfloat16, device="cpu")
     mha_linear0_weight = torch.testing.make_tensor(
         e * 3, e, dtype=torch.bfloat16, device="cpu"
     )
@@ -462,7 +471,9 @@ def test_transformer_forward(multidevice_direct_test, benchmark):
     # arguments. They are passed in in the same order as the `define_scalar`s
     # and `define_tensor`s.
     ins = [
-        torch.testing.make_tensor((b, s, e), dtype=torch.bfloat16, device="cuda"),
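+        # Tensor parallelism replicates the full input on every GPU; sequence
+        # parallelism copies only this rank's shard of the sequence dimension.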
+        inp.cuda()
+        if parallelism == Parallelism.TENSOR_PARALLEL
+        else multidevice_direct_test.shard_tensor(inp, 1, mesh),
         torch.testing.make_tensor((e,), dtype=torch.bfloat16, device="cuda"),
         torch.testing.make_tensor((e,), dtype=torch.bfloat16, device="cuda"),
         multidevice_direct_test.shard_tensor(mha_linear0_weight, 0, mesh),
@@ -479,7 +490,7 @@ def test_transformer_forward(multidevice_direct_test, benchmark):

     with FusionDefinition() as fd:
         transformer_forward_definition(fd, b, s, h, e)
-        transformer_forward_multidevice_schedule(fd, d)
+        transformer_forward_multidevice_schedule(fd, d, parallelism)

     warmup_fn, benchmark_fn = get_benchmark_fns(lambda: fd.execute(ins))

@@ -501,21 +512,23 @@ def test_transformer_forward(multidevice_direct_test, benchmark):
         out,
     ) = warmup_fn()

-    _assert_shape_dtype(layernorm0_mean, [b, s], torch.float32)
-    _assert_shape_dtype(layernorm0_rstd, [b, s, 1], torch.float32)
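+    # Sequence-parallel runs shard these activations along the sequence
+    # dimension, so their expected local extent is s // d.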
+    s_local = s // d if parallelism == Parallelism.SEQUENCE_PARALLEL else s
+
+    _assert_shape_dtype(layernorm0_mean, [b, s_local], torch.float32)
+    _assert_shape_dtype(layernorm0_rstd, [b, s_local, 1], torch.float32)
     _assert_shape_dtype(mha_linear0_out, [b, s, e * 3 // d], torch.bfloat16)
     _assert_shape_dtype(sdpa_out, [b, h // d, s, e // h], torch.bfloat16)
     _assert_shape_dtype(sdpa_logsum_exp, [b, h // d, s], torch.float32)
     ref_philox_seed, ref_philox_offset = create_sdpa_rng_tensors()
     _assert_shape_dtype(sdpa_seed, ref_philox_seed.shape, ref_philox_seed.dtype)
     _assert_shape_dtype(sdpa_offset, ref_philox_offset.shape, ref_philox_offset.dtype)
-    _assert_shape_dtype(mha_linear1_out, [b, s, e], torch.bfloat16)
-    _assert_shape_dtype(mha_dropout_mask, [b, s, e], torch.bool)
-    _assert_shape_dtype(layernorm1_mean, [b, s], torch.float32)
-    _assert_shape_dtype(layernorm1_rstd, [b, s, 1], torch.float32)
+    _assert_shape_dtype(mha_linear1_out, [b, s_local, e], torch.bfloat16)
+    _assert_shape_dtype(mha_dropout_mask, [b, s_local, e], torch.bool)
+    _assert_shape_dtype(layernorm1_mean, [b, s_local], torch.float32)
+    _assert_shape_dtype(layernorm1_rstd, [b, s_local, 1], torch.float32)
     _assert_shape_dtype(mlp_linear0_out, [b, s, e * 4 // d], torch.bfloat16)
-    _assert_shape_dtype(mlp_dropout_mask, [b, s, e], torch.bool)
-    _assert_shape_dtype(out, [b, s, e], torch.bfloat16)
+    _assert_shape_dtype(mlp_dropout_mask, [b, s_local, e], torch.bool)
+    _assert_shape_dtype(out, [b, s_local, e], torch.bfloat16)

     # Benchmark and profile. The profile can be collected and displayed using
     # `nsys`. See instructions in test_transformer_engine.py.