
Conversation

@Priya2698 (Collaborator) commented Nov 20, 2025

Wall-clock time measured on 8 80GB H100 nodes:

TE: 2.5 ms
nvFuser: 2.1 ms

github-actions bot commented Nov 20, 2025

Review updated until commit 68e8a9a

Description

  • Add sequence parallelism support for transformer forward pass

  • Implement sequence dimension splitting and mesh parallelization

  • Add parametrized testing for both tensor and sequence parallelism

  • Update shape assertions to handle sequence-parallel tensor dimensions

  • Move Parallelism enum to shared benchmark_utils for reuse

Changes walkthrough

Relevant files:

Enhancement

benchmark_utils.py: Add Parallelism enum to benchmark utilities
tests/python/multidevice/benchmark_utils.py (+8/-0)
  • Add Parallelism enum with TENSOR_PARALLEL and SEQUENCE_PARALLEL options (sketched below)
  • Move enum definition from test_transformer_engine.py to shared utility

test_transformer.py: Implement sequence parallel transformer forward testing
tests/python/multidevice/test_transformer.py (+43/-30)
  • Add sequence parallelism logic with outer_split and mesh_x parallelization
  • Parametrize test for both tensor and sequence parallelism modes
  • Update input tensor creation to handle sequence-parallel sharding
  • Modify shape assertions for sequence-divided dimensions
  • Add skip condition for incompatible sequence length/device count

Refactoring

test_transformer_engine.py: Remove duplicate Parallelism enum definition
tests/python/multidevice/test_transformer_engine.py (+1/-8)
  • Remove duplicate Parallelism enum definition
  • Import Parallelism from shared benchmark_utils module
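
The Parallelism enum moved into benchmark_utils.py is small enough to sketch here. This is a minimal illustration of the shared enum described above, assuming auto() member values; the actual definition in the PR may assign explicit values.

    from enum import Enum, auto

    # Shared across the multidevice tests so both test_transformer.py and
    # test_transformer_engine.py can parametrize over the same modes.
    class Parallelism(Enum):
        TENSOR_PARALLEL = auto()
        SEQUENCE_PARALLEL = auto()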

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Sequence Parallelism Implementation

The sequence parallelism implementation appears correct, with proper tensor sharding and dimension handling. The input tensor is correctly sharded along the sequence dimension (axis 1), and the local sequence length calculations look accurate. However, verify that all intermediate tensor shapes and operations are handled properly for sequence parallelism, particularly operations that span the sequence dimension.

    if parallelism == Parallelism.SEQUENCE_PARALLEL and s % d != 0:
        pytest.skip(
            f"Sequence length {s} must be divisible by the number "
            f"of devices {d} for sequence parallelism."
        )
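
For context, here is a minimal sketch of what the sequence-parallel branch of the schedule does, based on the outer_split and mesh parallelization calls referenced in this PR's summary and in the sequence diagram further down; the variable names (inp, d, mesh_x) are assumptions, not a verbatim excerpt:

    # Sketch only: bind the outer factor of the sequence dimension to the mesh,
    # so each of the d ranks owns s // d tokens of the input.
    if parallelism == Parallelism.SEQUENCE_PARALLEL:
        inp.outer_split(1, d)               # split axis 1 (sequence) into [d, s // d]
        inp.axis(1).parallelize(mesh_x)     # distribute the outer factor across devices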
Performance Validation

The PR shows promising performance results (2.1 ms for nvFuser vs 2.5 ms for TE), but ensure comprehensive benchmarking across different model sizes and sequence lengths. Validate that sequence parallelism provides consistent benefits and doesn't introduce regressions for smaller models or different parallelism configurations.

    benchmark.pedantic(benchmark_fn, rounds=5)
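
The rounds argument above comes from pytest-benchmark's pedantic mode. As a self-contained reference for how that API behaves (requires the pytest-benchmark plugin), the toy test below is illustrative only; the workload and the warmup_rounds value are assumptions, not taken from this PR:

    def test_benchmark_example(benchmark):
        # Trivial stand-in for the transformer forward pass being timed.
        def benchmark_fn():
            return sum(range(10_000))

        # warmup_rounds runs the target beforehand without recording timings;
        # rounds controls how many timed repetitions feed the statistics.
        benchmark.pedantic(benchmark_fn, warmup_rounds=2, rounds=5)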

@Priya2698 Priya2698 changed the base branch from main to pm/decompose_linear November 20, 2025 14:28
@Priya2698 Priya2698 changed the base branch from pm/decompose_linear to main November 20, 2025 14:28
@Priya2698 Priya2698 requested a review from wujingyue December 4, 2025 23:49
@Priya2698 Priya2698 marked this pull request as ready for review December 4, 2025 23:49
@Priya2698 (Collaborator, Author) commented:

!build

greptile-apps bot commented Dec 4, 2025

Greptile Overview

Greptile Summary

This PR adds sequence parallelism support to the transformer forward pass test, complementing the existing tensor parallelism.

• Moved Parallelism enum from test_transformer_engine.py to benchmark_utils.py for reuse
• Extended transformer_forward_multidevice_schedule() to handle sequence parallelism by splitting input along the sequence dimension
• Parametrized test_transformer_forward to test both tp and sp modes
• Added proper divisibility check for sequence length with number of devices
• Updated shape assertions to use s_local (local sequence length) for sequence-parallelized outputs (see the sketch below)

The benchmarks show nvFuser achieves 2.1 ms vs TransformerEngine's 2.5 ms on 8 H100 nodes.
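
For readers skimming the diff, a minimal sketch of the input-side handling described above, based on the shard_tensor call and the s_local bookkeeping named in this summary and the sequence diagram; the exact helper signature and surrounding variable names are assumptions, not verbatim test code:

    # Sketch: `inp` is the full [b, s, e] host input, `d` the device count.
    if parallelism == Parallelism.SEQUENCE_PARALLEL:
        local_inp = shard_tensor(inp, dim=1)  # each rank keeps s // d tokens
        s_local = s // d
    else:
        local_inp = inp.cuda()                # full tensor on every rank
        s_local = s
    # Shape assertions for sequence-parallelized outputs then compare against
    # s_local instead of s, e.g. an activation shaped [b, s_local, e] per rank.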

Confidence Score: 5/5

• This PR is safe to merge: it adds well-structured test coverage for sequence parallelism without modifying core functionality.
• The changes are limited to test files, follow existing patterns in the codebase, and include proper validation checks. The sequence parallelism implementation mirrors the established tensor parallelism pattern with correct dimension handling.
• No files require special attention.

Important Files Changed

File Analysis

tests/python/multidevice/benchmark_utils.py (5/5): Added Parallelism enum (TENSOR_PARALLEL, SEQUENCE_PARALLEL) for reuse across test files. Clean refactoring with no issues.
tests/python/multidevice/test_transformer.py (5/5): Added sequence parallel support to the transformer forward test with correct input sharding on dimension 1, proper divisibility checks, and accurate local shape assertions for sequence-parallelized outputs.
tests/python/multidevice/test_transformer_engine.py (5/5): Removed duplicate Parallelism enum definition; now imports from benchmark_utils.py. No functional changes.

Sequence Diagram

    sequenceDiagram
        participant Test as test_transformer_forward
        participant FD as FusionDefinition
        participant Schedule as multidevice_schedule
        participant Exec as fd.execute
    
        Test->>Test: Select parallelism (tp/sp)
        Test->>Test: Create input tensor (CPU)
        
        alt TENSOR_PARALLEL
            Test->>Test: inp.cuda() (full tensor)
        else SEQUENCE_PARALLEL
            Test->>Test: shard_tensor(inp, dim=1) (split sequence)
        end
        
        Test->>FD: transformer_forward_definition(b, s, h, e)
        Test->>Schedule: transformer_forward_multidevice_schedule(fd, d, parallelism)
        Schedule->>Schedule: set_device_mesh for all inputs
        
        alt SEQUENCE_PARALLEL
            Schedule->>Schedule: inp.outer_split(1, num_devices)
            Schedule->>Schedule: inp.axis(1).parallelize(mesh_x)
        end
        
        Schedule->>Schedule: Split weight tensors for tensor parallelism
        
        Test->>Exec: fd.execute(ins)
        Exec-->>Test: outputs (with s_local for SP mode)
    

@greptile-apps bot left a comment:

3 files reviewed, no comments



Review thread on the changed signature of transformer_forward_multidevice_schedule in tests/python/multidevice/test_transformer.py:

    def transformer_forward_multidevice_schedule(fd: FusionDefinition, num_devices: int):
    def transformer_forward_multidevice_schedule(
A Collaborator commented:
    With direct bindings, the schedule doesn't have to live in a separate function from the definition. This way, you can remove the input-unpacking code. (This could have been addressed when the test was migrated from legacy bindings to direct bindings, but I didn’t push for it at the time to keep the migration moving quickly.)

@Priya2698 (Collaborator, Author) replied:
Got it.
I prefer keeping it separate. It's easier to focus on just the schedule (with the downside of input unpacking), as the definition is quite long.

@Priya2698 (Collaborator, Author) commented:

!test

@greptile-apps bot left a comment:

3 files reviewed, no comments

@Priya2698 (Collaborator, Author) commented:

!test

@greptile-apps bot left a comment:

3 files reviewed, no comments

@Priya2698 Priya2698 merged commit e224b26 into main Dec 5, 2025
28 of 30 checks passed
@Priya2698 Priya2698 deleted the pm/transformer_sp_forward branch December 5, 2025 18:41