
Conversation

@Priya2698 (Collaborator) commented Oct 4, 2025

The previous implementation forced the allocation and loop domains to be identical due to limitations in the stack (#4381). With recent changes, we can support different allocation and loop domains, so that reordering is no longer necessary.

This PR reorders device IDs to the front of the loop domain only; in the allocation domain they remain at their original positions.
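To make the intended ordering concrete, here is a minimal, standalone sketch (not the actual nvFuser pass; the Axis struct and axis names are hypothetical, and std::stable_partition only mimics what reorderParallelizedToFront does to the loop domain):

```cpp
// Conceptual sketch: device-parallel axes move to the front of the loop
// ordering, while the allocation ordering keeps them at their original
// positions. All names here are illustrative, not nvFuser types.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

struct Axis {
  std::string name;
  bool is_device_dim;
};

int main() {
  // Hypothetical tensor whose middle axis is sharded across devices (iDIDx).
  std::vector<Axis> allocation = {{"i0", false}, {"iDIDx", true}, {"i1", false}};

  // The loop ordering starts as a copy and then stably moves device dims to
  // the front; the allocation ordering is left untouched.
  std::vector<Axis> loop = allocation;
  std::stable_partition(
      loop.begin(), loop.end(), [](const Axis& a) { return a.is_device_dim; });

  for (const Axis& a : loop) {
    std::cout << a.name << ' ';  // prints: iDIDx i0 i1
  }
  std::cout << '\n';
  for (const Axis& a : allocation) {
    std::cout << a.name << ' ';  // prints: i0 iDIDx i1
  }
  std::cout << '\n';
  return 0;
}
```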

Follow-up PRs will update this pass for the Stream parallel type.

Benchmark results: No notable difference.

The table below compares the maximum across all ranks of each rank's minimum time for the Transformer forward and backward passes. Results are for 8 H100 GPUs.

| Test                                  | Main   | PR     |
|---------------------------------------|--------|--------|
| Forward (test_transformer_forward)    | 2.2129 | 2.2107 |
| Backward (test_transformer_backward)  | 4.1178 | 4.1033 |

github-actions bot commented Oct 4, 2025

Review updated until commit 2e4e25b

Description

  • Fix vectorization validation by ignoring device dimensions

  • Update sharded ID lookup to skip reduction domains

  • Remove redundant loop-allocation domain reordering

  • Strengthen error checking for logical axis mapping


Changes walkthrough 📝

Relevant files

Bug fix

csrc/device_lower/validation.cpp: Skip device dims in vectorization validation (+1/-1)
  • Skip device dimensions during vectorization validation
  • Prevent device dims from affecting contiguity checks

Enhancement

csrc/multidevice/utils.cpp: Improve sharded ID lookup and error handling (+5/-4)
  • Use kNoReductions filter when searching sharded IDs
  • Replace index-based loop with direct domain iteration
  • Add NVF_ERROR for missing producing logical axis

csrc/preseg_passes/finalize_multidevice_domains.cpp: Simplify allocation domain finalization (+2/-33)
  • Remove redundant loop domain reordering
  • Simplify allocation domain setup
  • Directly apply allocation domain without permutation
  • Call reorderParallelizedToFront unconditionally

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Logic Change

The condition in the loop domain validation has been extended to skip device dimensions, which may affect contiguity checks. This change should be validated to ensure it does not inadvertently skip meaningful non-reduction, non-broadcast dimensions that are not device-related.

if (r_id->isReduction() || r_id->isBroadcast() || r_id->isDeviceDim()) {
  continue;
}

Error Handling

The logic for handling cases where the producing logical axis is not found has changed from continuing the loop to throwing an error. This stricter behavior should be reviewed to confirm it is safe across all calling contexts and does not break existing valid use cases.

NVF_ERROR(
    sharded_axis != -1,
    "Producing logical axis not found for ",
    sharded_id);
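As a standalone illustration of the old versus new behavior (hypothetical helper and axis names; a plain exception stands in here for NVF_ERROR, which throws when its condition is false):

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Returns the index of the logical axis matching `sharded`, or -1 if absent.
int64_t findProducingLogicalAxis(
    const std::vector<std::string>& logical,
    const std::string& sharded) {
  for (size_t i = 0; i < logical.size(); ++i) {
    if (logical[i] == sharded) {
      return static_cast<int64_t>(i);
    }
  }
  return -1;
}

int main() {
  const std::vector<std::string> logical = {"i0", "i1"};
  const int64_t sharded_axis = findProducingLogicalAxis(logical, "i1");

  // Previously a -1 result was silently skipped (continue), hiding the
  // inconsistency; the PR instead reports it as an error right away.
  if (sharded_axis == -1) {
    throw std::runtime_error("Producing logical axis not found for i1");
  }
  return 0;
}
```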
Allocation-Loop Domain Mismatch

The PR removes logic that enforces alignment between allocation and loop domains via permutation and reordering. While this is intentional, the removal of safety checks and the new assumption of independent domain ordering should be carefully validated, especially for resharding cases.

  tv->setAllocationDomain(new_allocation_domain, new_contiguity);
  reorderParallelizedToFront(tv);
}

@Priya2698 (Collaborator, Author):

!test --diff

@Priya2698 (Collaborator, Author):

!test

@Priya2698 (Collaborator, Author):

!test

@Priya2698 requested a review from wujingyue October 7, 2025 05:02
@Priya2698 marked this pull request as ready for review October 7, 2025 05:02
@Priya2698 (Collaborator, Author) commented on the diff:

  }();

- for (auto&& [index, id] : enumerate(domain)) {
+ for (auto&& [index, id] : enumerate(domain | TensorDomain::kNoReductions)) {

For reduce-scatter outputs, we may return the rDIDx in unshardedSizes if it is ordered before iDIDx in the loop domain, which leads to incorrect shape deduction (e.g., LowerCollectiveTest.NoncontigReduceScatter).
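The following standalone sketch reproduces that failure mode with hypothetical axis flags (std::copy_if stands in for the TensorDomain::kNoReductions filter):

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <vector>

struct Axis {
  std::string name;
  bool is_device_dim;
  bool is_reduction;
};

int main() {
  // Loop domain of a reduce-scatter output where rDIDx is ordered before iDIDx.
  const std::vector<Axis> loop = {
      {"rDIDx", true, true}, {"iDIDx", true, false}, {"i1", false, false}};

  auto first_device = [](const std::vector<Axis>& dom) {
    return *std::find_if(
        dom.begin(), dom.end(), [](const Axis& a) { return a.is_device_dim; });
  };

  // Without filtering, the search returns the reduction axis, which would make
  // the unsharded-shape deduction use the wrong dimension.
  assert(first_device(loop).name == "rDIDx");

  // Filtering out reductions first (the kNoReductions idea) yields the
  // intended iteration axis instead.
  std::vector<Axis> no_reductions;
  std::copy_if(
      loop.begin(), loop.end(), std::back_inserter(no_reductions),
      [](const Axis& a) { return !a.is_reduction; });
  assert(first_device(no_reductions).name == "iDIDx");
  return 0;
}
```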

@Priya2698 (Collaborator, Author) commented on the diff:

  for (size_t i = tv->getMaybeAllocationDomain().size(); i > 0; i--) {
    auto r_id = tv->getMaybeAllocationDomain()[i - 1];
-   if (r_id->isReduction() || r_id->isBroadcast()) {
+   if (r_id->isReduction() || r_id->isBroadcast() || r_id->isDeviceDim()) {

DIDx is not reordered to the front of the allocation domain and may appear in the innermost position, e.g., test_matmul.test_linear_reduce_scatter (the copy kernel is pointwise scheduled and codegen'd).
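A standalone sketch of what that skip buys during the vectorization/contiguity scan (hypothetical axis flags, not the nvFuser implementation):

```cpp
#include <iostream>
#include <string>
#include <vector>

struct Axis {
  std::string name;
  bool is_reduction = false;
  bool is_broadcast = false;
  bool is_device_dim = false;
};

int main() {
  // Hypothetical allocation domain with the device axis in the innermost
  // position, as in the pointwise-scheduled copy kernel mentioned above.
  const std::vector<Axis> allocation = {
      {"i0"}, {"i1"}, {"iDIDx", false, false, true}};

  // Scan from the innermost dimension outward, ignoring reduction, broadcast,
  // and (now) device dimensions, mirroring the updated condition in the diff.
  for (size_t i = allocation.size(); i > 0; i--) {
    const Axis& r_id = allocation[i - 1];
    if (r_id.is_reduction || r_id.is_broadcast || r_id.is_device_dim) {
      continue;
    }
    std::cout << "innermost vectorizable candidate: " << r_id.name << '\n';  // i1
    break;
  }
  return 0;
}
```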

@wujingyue (Collaborator) left a comment:

Wonderful!

@Priya2698 (Collaborator, Author):

!test --diff

@Priya2698 merged commit 4b8dd52 into main Oct 7, 2025
59 of 61 checks passed
@Priya2698 deleted the pm/alloc_order branch October 7, 2025 23:40