
Conversation

@Priya2698 (Collaborator)

No description provided.

github-actions bot commented Jan 7, 2026

Review updated until commit 2bb6353

Description

  • Implement AllToAll communication primitive in nvFuser framework

  • Add lowerToAllToAll function with mesh validation and IR creation

  • Implement postAllToAll backend function using alltoall_base

  • Add comprehensive test demonstrating AllToAll usage

Changes walkthrough

Relevant files

Enhancement

lower_to_communication.cpp (csrc/host_ir/lower_to_communication.cpp): Add AllToAll lowering infrastructure (+48/-3)

  • Add lowerToAllToAll function with mesh validation
  • Update getCommunicationInfo to handle AllToAll cases
  • Update getCommunicationLayout to include AllToAll
  • Add AllToAll case to convertSingleOpToCommunication

communication.cpp (csrc/multidevice/communication.cpp): Implement AllToAll backend communication (+73/-0)

  • Add AllToAll enum value and operator<< support
  • Update hasRoot and isReduction functions for AllToAll
  • Implement postAllToAll function with tensor reshaping logic
  • Add AllToAll case to postSingleCommunication

communication.h (csrc/multidevice/communication.h): Add AllToAll enum value (+2/-1)

  • Add AllToAll to CommunicationType enum

Tests

test_multidevice.py (tests/python/multidevice/test_multidevice.py): Add AllToAll test case (+73/-0)

  • Add test_alltoall function with comprehensive example
  • Demonstrate tensor sharding and AllToAll operation
  • Show permute operations for proper tensor layout

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

AllToAll Implementation

The postAllToAll function implements the tensor reshaping and reordering logic. It assumes specific tensor shapes and may need validation for edge cases. The reshape and permute followed by contiguous() introduce an extra copy of the input, which could impact performance. The implementation should be verified against the c10d::Backend alltoall_base requirements, and error handling for invalid tensor configurations should be confirmed.

    c10::intrusive_ptr<c10d::Work> postAllToAll(
        Communication* communication,
        DeviceIdxType my_device_index,
        c10d::Backend* backend,
        at::Tensor input_tensor,
        at::Tensor output_tensor) {
      NVF_ERROR(
          isTvContiguous(communication->in()),
          "Input tensor is not contiguous: ",
          communication->in(),
          " contiguity: ",
          communication->in()->domain()->getContiguityString());
      NVF_ERROR(
          isTvContiguous(communication->out()),
          "Output tensor is not contiguous: ",
          communication->out(),
          " contiguity: ",
          communication->out()->domain()->getContiguityString());
    
      // input_tv = [DIDx(d), n/d, m, ...]
      // output_tv = [n, DIDx(d), m/d, ...]
      // `n`: gathered dimension
      // `m`: scattered dimension
      // For alltoall correctness, we split `m` and reorder as [DIDx(d), d, n/d,
      // m/d, ...] such that alltoall_base splits across the `d` dimension.
    
      int64_t d = communication->team_size();
      auto input_sizes = input_tensor.sizes();
    
      NVF_CHECK(
          input_sizes.at(1) % d == 0,
          "Scattered dimension must be divisible by the team size");
    
      std::vector<int64_t> input_reshape_sizes(
          input_sizes.begin(), input_sizes.end());
      input_reshape_sizes.at(1) = d;
      input_reshape_sizes.insert(
          input_reshape_sizes.begin() + 2, input_sizes.at(1) / d);
      auto reshaped_input = input_tensor.reshape(input_reshape_sizes);
    
      std::vector<int64_t> permute_dims(input_reshape_sizes.size());
      std::iota(permute_dims.begin(), permute_dims.end(), 0);
      std::swap(permute_dims[0], permute_dims[1]);
    
      auto reordered_input = reshaped_input.permute(permute_dims).contiguous();
    
      auto flattened_input_tensor = viewAsCompact(reordered_input);
      auto flattened_output_tensor = viewAsCompact(output_tensor);
    
      // alltoall_base requires even splits of the input and output tensors.
      auto input_splits = at::tensor_split(
          flattened_input_tensor, communication->team_size(), /*dim=*/0);
      auto output_splits = at::tensor_split(
          flattened_output_tensor, communication->team_size(), /*dim=*/0);
      assertBuffersHaveSameSize(input_splits, output_splits);
    
      std::vector<int64_t> empty_split_sizes;
      return backend->alltoall_base(
          flattened_output_tensor,
          flattened_input_tensor,
          empty_split_sizes,
          empty_split_sizes,
          /*options=*/{});
    }
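
To make the buffer preparation concrete, below is a minimal, single-device PyTorch sketch (my illustration, not code from the PR or nvFuser). It assumes the DIDx(d) axis is not materialized in the local at::Tensor, so the local input shard is [n/d, m, k] and the scattered dimension m sits at index 1, which is what the divisibility check above implies. The sizes d, n, m, k are arbitrary example values.

    import torch

    # Example sizes (not from the PR); n is the gathered dimension, m the
    # scattered dimension, k a trailing dimension.
    d, n, m, k = 4, 8, 12, 3
    local_in = torch.randn(n // d, m, k)  # local input shard, [n/d, m, k]

    # Split the scattered dimension m into [d, m/d]:
    # [n/d, m, k] -> [n/d, d, m/d, k].
    reshaped = local_in.reshape(n // d, d, m // d, k)

    # Swap dims 0 and 1 (the iota + swap in the C++ code) so the d per-peer
    # blocks are outermost, and materialize the new order: [d, n/d, m/d, k].
    reordered = reshaped.permute(1, 0, 2, 3).contiguous()

    # alltoall_base with empty split sizes sends d equal slices along dim 0;
    # slice i holds the [n/d, m/d, k] block destined for peer i. After the
    # exchange, the d received blocks sit back to back in the output buffer,
    # giving a local output shard of shape [n, m/d, k].
    blocks = torch.tensor_split(reordered, d, dim=0)
    assert all(b.numel() == (n // d) * (m // d) * k for b in blocks)
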
Communication Detection Logic

The getCommunicationInfo function was modified to add AllToAll detection logic. The condition checks if c_logical_id matches p2c_map.at(p_logical_id) to decide between SendRecv and AllToAll. This logic should be validated to ensure it correctly identifies AllToAll scenarios and doesn't introduce false positives or negatives in communication type detection.

    if (c_logical_id == p2c_map.at(p_logical_id)) {
      fill_communication_info(
          CommunicationType::SendRecv, p_logical_id, c_logical_id);
    } else {
      fill_communication_info(
          CommunicationType::AllToAll, nullptr, nullptr);
    }
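
One way to read this branch: if the consumer's sharded logical ID is the image of the producer's sharded logical ID under the pairwise producer-to-consumer map, the sharded axis is preserved and a SendRecv suffices; if the sharding moves to a different axis, every device must exchange a slice with every other device, which is an AllToAll. The sketch below models that rule with plain strings and a dict purely for exposition; it is not nvFuser API, and the ID names are made up.

    # Toy model of the branch above (not nvFuser API): IterDomains and the
    # pairwise producer-to-consumer map are replaced by hypothetical strings
    # and a dict.
    def classify(p_sharded_id, c_sharded_id, p2c_map):
        if c_sharded_id == p2c_map[p_sharded_id]:
            # The producer's sharded axis maps onto the consumer's sharded
            # axis: no resharding across axes, so a SendRecv is enough.
            return "SendRecv"
        # Sharding moves to a different axis: every device needs a slice from
        # every other device, i.e. an AllToAll.
        return "AllToAll"

    p2c_map = {"p_n": "c_n", "p_m": "c_m"}
    assert classify("p_n", "c_n", p2c_map) == "SendRecv"
    assert classify("p_n", "c_m", p2c_map) == "AllToAll"
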
Test Complexity

The test_alltoall function uses complex tensor permutations and reshaping. While the test appears comprehensive, the complexity of the setup (permute operations used to work around allocation-domain constraints) points to limitations in the current AllToAll implementation that should be documented or addressed. The test should be reviewed to ensure it adequately validates the core AllToAll functionality without depending too heavily on specific tensor layout manipulations.

    @pytest.mark.mpi
    def test_alltoall(multidevice_test):
        d = multidevice_test.size
        mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
        k, m, n = 3, 40, 56
    
        assert (
            m % d == 0 and n % d == 0
        ), "Sharded dimensions must be divisible by the team size"
    
        with FusionDefinition() as fd:
            inp = fd.define_tensor((k, m, n), contiguity=True, dtype=DataType.Half)
            # Ideally, we could have exposed and split the allocation domain of
            # alltoall input and output, as follows:
            #   Input of shape [k, m, n] sharded on n
            #   Output of shape [k, m, n] sharded on m
            #
            # Where d is an outer split from `n` and d* is an outer_split from `m`:
            #   input           [b, m, DIDx(d), n/d]
            #   alltoall_input  [d*, b, m/d*, DIDx(d), n/d]
            #     d* is reordered outermost. `alltoall_single` will slice on the
            #     outermost dimension.
            #   alltoall_output [d, b, DIDx(d*), m/d*, n/d]
            #     d is gathered back. AllToAll places data from each device in
            #     memory in order of rank.
            #   output          [b, DIDx(d*), m/d, n]
            #     Same as sharding original input along row dimension instead of
            #     column.
            #
            # However, tensors corresponding to allocation domains with non-adjacent
            # splits fail stride validation. To avoid this, I permute the gathered
            # and scattered dimensions to be outermost in that order. `postAllToAll`
            # will split the scattered dimension and reorder as
            # [DIDx(d), d, n/d, m/d, k] and make it contiguous. The output will be
            # [DIDx(d*), n, m, k]. This avoids the non-adjacent split in output and
            # avoids exposing the non-adjacent split of `m` in input to the fusion
            # definition. I am using `permute` instead of setting the allocation
            # domain to avoid another copy.
            all2all_inp = fd.ops.permute(inp, dims=[2, 1, 0])
            all2all_out = fd.ops.set(all2all_inp)
            out = fd.ops.permute(all2all_out, dims=[2, 1, 0])
            fd.add_output(all2all_inp)
            fd.add_output(all2all_out)
            fd.add_output(out)
    
            inp.set_device_mesh(mesh)
            inp.outer_split(2, d)
            inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x)
    
            all2all_inp.set_device_mesh(mesh)
            all2all_inp.outer_split(0, d)
            all2all_inp.axis(0).parallelize(
                nvfuser.ParallelType.mesh_x
            )  # [DIDx(d), n, d * m, k]
    
            all2all_out.set_device_mesh(mesh)
            all2all_out.outer_split(1, d)
            all2all_out.axis(1).parallelize(
                nvfuser.ParallelType.mesh_x
            )  # [d * n, DIDx(d), m, k]
            all2all_out.set_allocation_domain(all2all_out.get_loop_domain(), True)
    
            out.set_device_mesh(mesh)
            out.outer_split(1, d)
            out.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
            out.set_allocation_domain(out.get_loop_domain(), True)
    
        in_tensor = torch.randn(k, m, n, dtype=torch.float16)
        sharded = multidevice_test.shard_tensor(in_tensor, 2, mesh)
        (all2all_inp, all2all_out, out) = fd.execute([sharded])
        torch.testing.assert_close(out, multidevice_test.shard_tensor(in_tensor, 1, mesh))
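
As a sanity check of the semantics the test asserts (an input sharded on n comes out as the same logical tensor sharded on m), here is a single-process emulation in plain torch. This is my own sketch rather than part of the test suite: it performs the same per-rank reshaping as postAllToAll and exchanges blocks with indexing instead of a communication backend.

    import torch

    # Single-process emulation (illustrative only) of the resharding the test
    # checks. Sizes mirror the test; d stands in for the number of devices.
    d, k, m, n = 4, 3, 40, 56
    full = torch.randn(k, m, n)

    # Per-rank inputs: shard [k, m, n] on n, then permute to [n/d, m, k] as
    # the fusion definition does.
    in_shards = [s.permute(2, 1, 0) for s in full.chunk(d, dim=2)]

    # Build each rank's d send blocks the way postAllToAll would:
    # [n/d, m, k] -> [n/d, d, m/d, k] -> [d, n/d, m/d, k].
    send = [
        shard.reshape(n // d, d, m // d, k).permute(1, 0, 2, 3).contiguous()
        for shard in in_shards
    ]

    # Emulate the exchange: rank r receives block r from every rank, laid out
    # in rank order along the gathered dimension -> [n, m/d, k].
    out_shards = [
        torch.cat([send[src][r] for src in range(d)], dim=0) for r in range(d)
    ]

    # Permuting back to [k, m/d, n] must match sharding the full tensor on m.
    for r in range(d):
        torch.testing.assert_close(
            out_shards[r].permute(2, 1, 0), full.chunk(d, dim=1)[r]
        )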

@Priya2698 (Collaborator, Author)

!test

