
Conversation

Priya2698 (Collaborator) commented Nov 20, 2025

Some of the created TensorViews were not sharded consistently, which led to more communication than necessary.

github-actions bot commented Nov 20, 2025

Review updated until commit 97af3bf

Description

  • Fix decomposeLinearWithBias to shard all created TensorViews consistently (sketched below)

  • Add backpropagation of shardings to the intermediate expressions between the bias and the output

  • Update the broadcast mapping in propagation so broadcast dimensions are no longer mapped

  • Enhance test coverage with profiler validation for communication kernels
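
As a quick orientation for the first bullet, here is a minimal sketch of what the decomposition looks like, reconstructed from the sequence diagram and code excerpts later in this review. The names without_bias, broadcasted_bias, and new_out and the two selfReplay calls come from the PR itself; the op calls, their signatures, and the broadcast flags are illustrative assumptions, not the actual pass.

    // Hedged sketch of the decomposition performed by decomposeLinearWithBias,
    // reconstructed from the sequence diagram and diff excerpts in this review.
    // (nvFuser headers omitted; op signatures and broadcast flags are assumed.)
    TensorView* decomposeLinearWithBiasSketch(
        TensorView* in,
        TensorView* weight,
        TensorView* bias,
        TensorView* out /* original linear-with-bias output */) {
      // 1. Linear without the bias; with a row-parallel weight this produces
      //    partial sums that are later reduce-scattered.
      TensorView* without_bias = linear(in, weight);

      // 2. Bias broadcast up to the output rank (flags are illustrative).
      TensorView* broadcasted_bias = broadcast(bias, {true, false});

      // 3. The pass replaces the original output with new_out in the fusion
      //    (the rewiring itself is not shown here).
      TensorView* new_out = add(without_bias, broadcasted_bias);

      // 4. After rewiring the IR, replay out's transforms onto the new
      //    TensorViews so they are sharded like the original output
      //    (these two calls are quoted verbatim in the reviewer guide below).
      TransformReplay::selfReplay(out->domain(), without_bias->domain());
      TransformReplay::selfReplay(out->domain(), new_out->domain());

      // 5. A backward pass over the expressions between {without_bias,
      //    broadcasted_bias} and new_out then shards every intermediate
      //    TensorView consistently; see the loop quoted in the reviewer guide.
      return new_out;
    }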

Changes walkthrough

Relevant files

Bug fix: csrc/preseg_passes/decompose_reshardings.cpp (+22/-2)
Fix sharding consistency in decomposeLinearWithBias

  • Remove redundant selfReplay calls
  • Add a backpropagation loop to shard intermediate expressions consistently
  • Use shardLoopLike with backward propagation for proper sharding
  • Ensure all TensorViews between the bias and the output are properly sharded
Enhancement: csrc/multidevice/propagation.cpp (+6/-2)
Enhance broadcast mapping in propagation (a sketch follows this walkthrough)

  • Add mapBroadcast(false) to the forward direction mapping
  • Add mapBroadcast(false) to the backward direction mapping
  • Improve broadcast dimension handling in the logical domain mapping
Enhancement: csrc/runtime/communication_executor.cpp (+1/-1)
Update the communication executor profiler

  • Change the profiler scheduler type from ExprEval to Communication
  • Update kernel profiling to reflect communication operations
Tests: tests/python/multidevice/test_matmul.py (+23/-21)
Enhance the linear reduce-scatter test with bias and profiling

  • Update test_linear_reduce_scatter with a bias parameter and bfloat16
  • Add profiler validation of communication kernel scheduling
  • Modify tensor dimensions and initialization for better test coverage
  • Ensure a single communication kernel is executed
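
For the propagation.cpp entry above, here is a hedged sketch of how a getRef2TargetMap built on PairwiseLogicalDomainMap looks once .mapBroadcast(false) is applied in both directions. Only PairwiseLogicalDomainMap, .mapBroadcast(false), and PropagateDirection::kBackward appear in this PR; the function shape and the mapProducerToConsumer/mapConsumerToProducer calls are assumptions for illustration.

    // Hedged sketch, not the actual nvFuser source. (nvFuser headers omitted.)
    #include <unordered_map>

    std::unordered_map<IterDomain*, IterDomain*> getRef2TargetMapSketch(
        TensorView* producer,
        TensorView* consumer,
        PropagateDirection direction) {
      if (direction == PropagateDirection::kBackward) {
        // Backward: map consumer IterDomains back to the producer, skipping
        // broadcast IDs so a broadcast dimension never receives (or donates)
        // a device/stream parallel type.
        return PairwiseLogicalDomainMap(producer, consumer)
            .mapBroadcast(false)
            .mapConsumerToProducer();
      }
      // Forward: map producer IterDomains to their consumer counterparts,
      // again with broadcast dimensions excluded.
      return PairwiseLogicalDomainMap(producer, consumer)
          .mapBroadcast(false)
          .mapProducerToConsumer();
    }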

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Complex Sharding Logic

The new backpropagation loop for sharding intermediate expressions is complex and could potentially miss edge cases. The nested loops over expressions and TensorViews need careful validation to ensure all intermediate TensorViews are properly sharded without creating inconsistencies.

    for (Expr* expr : StmtSort::getExprsBetween(
                          {without_bias, broadcasted_bias}, {new_out}) |
         std::views::reverse) {
      for (auto* output : ir_utils::filterByType<TensorView>(expr->outputs())) {
        for (auto* input : ir_utils::filterByType<TensorView>(expr->inputs())) {
          shardLoopLike(
              /*ref=*/output,
              /*target=*/input,
              deviceAndStreamParallelTypes(),
              PropagateDirection::kBackward);
        }
      }
    }
TransformReplay Calls

Multiple TransformReplay::selfReplay calls are added, which could impact performance. The order and necessity of these replay operations should be validated to ensure they don't introduce unnecessary computational overhead.

    TransformReplay::selfReplay(out->domain(), without_bias->domain());
    TransformReplay::selfReplay(out->domain(), new_out->domain());

greptile-apps bot commented Nov 20, 2025

Greptile Overview

Greptile Summary

Fixed inconsistent sharding in decomposeLinearWithBias that caused unnecessary communication operations. The change ensures all intermediate TensorViews created during the decomposition are properly sharded by adding backward propagation logic and preventing broadcast dimensions from being mapped during sharding propagation.

Key improvements:

  • Reordered TransformReplay calls to apply after the IR graph modification
  • Added a backward propagation loop through the intermediate expressions to consistently shard all TensorViews between the bias and the output
  • Modified getRef2TargetMap to exclude broadcast dimensions (.mapBroadcast(false)), preventing incorrect dimension mapping
  • Fixed the profiler to report the correct communication scheduler type
  • Enhanced the test to validate that only one reduce-scatter operation is scheduled, confirming the fix reduces communication overhead

Confidence Score: 4/5

  • This PR is safe to merge with minor verification recommended.
  • The changes correctly address the sharding consistency issue with a well-structured solution. The backward propagation logic follows existing patterns in the codebase, and the test enhancement validates that the fix reduces communication overhead. The score reflects a solid implementation with comprehensive test coverage, though additional validation on diverse model architectures would provide extra confidence.
  • No files require special attention.

Important Files Changed

File Analysis

  • csrc/preseg_passes/decompose_reshardings.cpp (4/5): Fixed sharding propagation in decomposeRowParallelLinearWithBias by adding a backward propagation loop and reordering TransformReplay calls to ensure all intermediate TensorViews are consistently sharded
  • csrc/multidevice/propagation.cpp (5/5): Added .mapBroadcast(false) to the PairwiseLogicalDomainMap calls in getRef2TargetMap to prevent broadcast dimensions from being incorrectly mapped during sharding propagation
  • csrc/runtime/communication_executor.cpp (5/5): Corrected the scheduler type from ExprEval to Communication in the profiler for communication kernels
  • tests/python/multidevice/test_matmul.py (5/5): Enhanced test_linear_reduce_scatter to validate correct sharding by adding a bias parameter, profiling, and an assertion that only one communication kernel is scheduled

Sequence Diagram

    sequenceDiagram
        participant User as User Code
        participant Linear as decomposeRowParallelLinearWithBias
        participant Fusion as Fusion IR
        participant Propagate as shardLoopLike
        participant Map as PairwiseLogicalDomainMap
        
        User->>Linear: linear_with_bias operation
        Linear->>Fusion: Create without_bias = linear(A, B)
        Linear->>Fusion: Create broadcasted_bias
        Linear->>Fusion: Create new_out = add(without_bias, broadcasted_bias)
        Linear->>Fusion: Replace old out with new_out
        
        Note over Linear: Apply sharding transformations
        Linear->>Fusion: TransformReplay on without_bias
        Linear->>Fusion: TransformReplay on new_out
        
        Note over Linear: Backward propagate shardings
        loop For each expr (reverse order)
            loop For each output TV
                loop For each input TV
                    Linear->>Propagate: shardLoopLike(output, input, kBackward)
                    Propagate->>Map: getRef2TargetMap with mapBroadcast(false)
                    Map-->>Propagate: Domain mapping (excluding broadcasts)
                    Propagate->>Fusion: Apply sharding to input TV
                end
            end
        end
        
        Linear-->>User: Consistently sharded computation graph
    

greptile-apps bot left a comment

3 files reviewed, no comments

Priya2698 (Collaborator, Author) commented on csrc/runtime/communication_executor.cpp (quoted hunk):

        group_id_);
    SegmentProfiler& sprof = FusionProfiler::segment(group_id_);
    sprof.inputBytesAccessed(computeBytes(args));
    sprof.scheduler(toString(SchedulerType::ExprEval));

This caused the wrong scheduler name in the profiler output.
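
For reference, the +1/-1 change listed in the walkthrough for this file swaps the scheduler tag on the last quoted line; a sketch of the corrected hunk, assuming SchedulerType::Communication is the enum value the walkthrough refers to:

    // Sketch of the corrected hunk (only the scheduler tag changes; the
    // surrounding lines are quoted above).
    SegmentProfiler& sprof = FusionProfiler::segment(group_id_);
    sprof.inputBytesAccessed(computeBytes(args));
    sprof.scheduler(toString(SchedulerType::Communication));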

Priya2698 (Collaborator, Author) commented:

!test


Review thread on tests/python/multidevice/test_matmul.py (quoted diff hunk):

    (out,) = fd.execute([inp, weight])
    with PythonProfiler() as prof:
        (out,) = fd.execute([inp, weight, bias.cuda()])

A collaborator commented:

Does this synchronize? Could we miss kernels?

Priya2698 (Collaborator, Author) replied on Nov 21, 2025:

fd.execute should not return until the kernels have completed. There is a cudaStreamSynchronize at the end of the nsys trace too.

Is this what you are referring to?

wujingyue (Collaborator) replied on Nov 21, 2025:

There is a difference between cudaStreamSynchronize and cudaDeviceSynchronize, though: the former waits only for the work enqueued on that one stream, while the latter blocks the host until all streams on the device have finished.

Priya2698 (Collaborator, Author) replied on Nov 21, 2025:

You're right. I assumed cudaStreamSynchronize would be enough here, but the pointwise kernel and the NCCL call run on different streams.

FusionProfiler/PythonProfiler synchronize at start but not on stop, so I will add an explicit call here.

Note for myself: see whether FusionProfiler should synchronize before reading data.
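
A minimal, standalone CUDA runtime example of the distinction discussed in this thread (the buffer and stream names are made up; this is not nvFuser code): synchronizing one stream does not wait for work enqueued on another stream, whereas cudaDeviceSynchronize blocks the host until every stream on the device has drained.

    #include <cuda_runtime.h>

    int main() {
      float* buf = nullptr;
      cudaMalloc(&buf, 2 * sizeof(float));

      cudaStream_t compute_stream, comm_stream;
      cudaStreamCreate(&compute_stream);
      cudaStreamCreate(&comm_stream);

      // Stand-ins for the pointwise kernel and the NCCL collective, which run
      // on different streams in the scenario above.
      cudaMemsetAsync(buf, 0, sizeof(float), compute_stream);
      cudaMemsetAsync(buf + 1, 0, sizeof(float), comm_stream);

      // Waits only for compute_stream; the work on comm_stream may still be in
      // flight, so reading results or profiler data here could miss it.
      cudaStreamSynchronize(compute_stream);

      // Blocks the host until all streams on the device are idle.
      cudaDeviceSynchronize();

      cudaStreamDestroy(compute_stream);
      cudaStreamDestroy(comm_stream);
      cudaFree(buf);
      return 0;
    }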

Priya2698 (Collaborator, Author) commented:

!test

greptile-apps bot left a comment

4 files reviewed, no comments

Priya2698 (Collaborator, Author) commented:

!test

greptile-apps bot left a comment

4 files reviewed, no comments

Priya2698 merged commit 5d8efce into main on Dec 4, 2025. 59 of 60 checks passed.
Priya2698 deleted the pm/decompose_linear branch on December 4, 2025 at 23:30.