
indexAccumulate python api#4066

Merged
jjsjann123 merged 30 commits into main from jjsjann123/index_put_python_api
Apr 23, 2025

Conversation

@jjsjann123 (Collaborator) commented Mar 12, 2025

This PR supports embedding backward, which requires `torch.index_put_(..., accumulate=True)`.

Stacked PRs:

What this PR does:

  • Added Python API `Tensor fd.ops.index_accumulate(Tensor acc, Tensor index, Tensor value)`
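For readers unfamiliar with the accumulate semantics, here is a minimal PyTorch sketch of what the new op computes; the reference behavior is `torch.index_put` with `accumulate=True`, and the tensor names below are illustrative:

```python
import torch

def index_put_accumulate_ref(acc, index, value):
    # Out-of-place reference: adds value[i] into a copy of acc at row index[i],
    # summing contributions when the same index appears more than once.
    return torch.index_put(acc, (index,), value, accumulate=True)

acc = torch.zeros(4, 2)
index = torch.tensor([0, 2, 0])  # index 0 appears twice
value = torch.ones(3, 2)
out = index_put_accumulate_ref(acc, index, value)
# row 0 receives two contributions, row 2 receives one, rows 1 and 3 none
```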

github-actions Bot commented Mar 12, 2025

Review updated until commit 4896186

Description

  • Added Python API for index_put_accumulate

  • Included test cases for index_put_accumulate

  • Updated serialization and deserialization for IndexPutAccumulateOpRecord

  • Added reference implementation for index_put_accumulate


Changes walkthrough 📝

Relevant files

Enhancement

  • python_bindings.cpp (csrc/python_frontend/python_bindings.cpp) — Add Python API for index_put_accumulate
    Added index_put_accumulate function to nvf_ops; included documentation for index_put_accumulate. +42/-0
  • fusion_record.cpp (csrc/serde/fusion_record.cpp) — Register parser for IndexPutAccumulateOpRecord
    Registered parser for IndexPutAccumulateOpRecord. +7/-0
  • fusion_record.h (csrc/python_frontend/fusion_record.h) — Define IndexPutAccumulateOpRecord
    Defined IndexPutAccumulateOpRecord struct. +24/-0
  • fusion_cache.fbs (csrc/serde/fusion_cache.fbs) — Add IndexPutAccumulateOp to RecordType
    Added IndexPutAccumulateOp to RecordType enum. +1/-0

Tests

  • opinfo_input_generators.py (tests/python/opinfo_input_generators.py) — Add test case generator for index_put_accumulate
    Added index_put_accumulate_generator for test cases. +19/-0
  • opinfos.py (tests/python/opinfos.py) — Add opinfo for index_put_accumulate
    Added index_put_accumulate_opinfo to shape_ops; included reference implementation index_put_accumulate_ref. +28/-0

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Documentation

The docstring for index_put_accumulate could be more detailed, especially regarding the restrictions and limitations compared to the full torch.index_put(..., accumulate=true).

        Accumulates values into a tensor at specified indices.
    
        This function performs a restricted version of `torch.index_put`.
        It adds the values from `value_tv` to the elements of `acc_tv` at the indices
        specified by `index_tv`.
    
        acc_tv: The tensor to accumulate into (in-place modification).
        index_tv: The tensor containing the indices.
        value_tv: The tensor containing the values to accumulate.
    
        Returns:
            An alias to the modified `acc_tv` tensor.
    
        Note:
            This is a restricted version and may not support all features of the
            full `torch.index_put(..., accumulate=true)` function.
    )doc");
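The "alias to the modified `acc_tv`" contract in that docstring corresponds to PyTorch's in-place variant, which returns the same storage it mutates. A small sketch of that semantic (plain PyTorch, values chosen for illustration):

```python
import torch

acc = torch.zeros(5)
index = torch.tensor([1, 1, 3])
value = torch.tensor([10.0, 5.0, 2.0])

# The in-place variant returns an alias of acc, matching the
# "alias to the modified acc_tv" contract described above.
ret = acc.index_put_((index,), value, accumulate=True)
# acc is now [0, 15, 0, 2, 0]: both writes to index 1 accumulated (10 + 5)
```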
Error Handling

The IndexPutAccumulateOpRecord operator does not include any error handling. Consider adding checks for tensor dimensions and types to prevent runtime errors.

    void operator()(FusionState& fd) final {
      auto acc = fd.getFusionState(args_.at(0).index)->as<TensorView>();
      auto index = fd.getFusionState(args_.at(1).index)->as<TensorView>();
      auto value = fd.getFusionState(args_.at(2).index)->as<TensorView>();
    
      auto output = indexPutAccumulate(acc, index, value);
      fd.setFusionState(outputs_.at(0).index, output);
    }
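The kind of validation the reviewer suggests could look like the following Python-side sketch. The function name, the exact checks, and the error messages are hypothetical, not part of the PR:

```python
import torch

def check_index_put_accumulate_args(acc, index, value):
    # Hypothetical validation mirroring the reviewer's suggestion:
    # reject inputs before they reach the fusion record.
    if index.dim() != 1:
        raise ValueError(f"index must be 1-D, got {index.dim()}-D")
    if index.dtype not in (torch.int32, torch.int64):
        raise TypeError(f"index must be an integer tensor, got {index.dtype}")
    expected = (index.shape[0],) + tuple(acc.shape[1:])
    if tuple(value.shape) != expected:
        raise ValueError(f"value shape {tuple(value.shape)} != expected {expected}")
    if value.dtype != acc.dtype:
        raise TypeError("acc and value must share a dtype")
```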
Test Coverage

The index_put_accumulate_generator only tests a single case. Consider adding more test cases with different tensor shapes, data types, and edge cases to ensure robustness.

    def index_put_accumulate_generator(
        op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs
    ):
        make_arg = partial(
            make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad
        )
        make_index = partial(make_tensor, device="cuda", requires_grad=False)
    
        # vocab_size, hidden_size, seq_size
        cases = ((1024, 12, 300),)
    
        for vocab, hidden, seq in cases:
            for index_dtype in [torch.int, torch.long]:
                acc = make_arg((vocab, hidden))
                index = make_index((seq,), low=0, high=vocab, dtype=index_dtype)
                value = make_arg((seq, hidden))
                yield SampleInput(acc, index, value)
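Extending the generator along the reviewer's suggestion might look like the sketch below. The extra shapes are illustrative edge cases, and the CPU device here is only so the sketch runs anywhere; the real generator targets CUDA and yields SampleInput objects:

```python
from functools import partial

import torch
from torch.testing import make_tensor

def extended_index_put_accumulate_cases():
    make_arg = partial(make_tensor, device="cpu", dtype=torch.float32)
    make_index = partial(make_tensor, device="cpu", requires_grad=False)

    # (vocab_size, hidden_size, seq_size): original case plus edge cases
    cases = (
        (1024, 12, 300),  # original case
        (2, 1, 1),        # minimal sizes
        (16, 7, 64),      # non-power-of-two hidden size
    )
    for vocab, hidden, seq in cases:
        for index_dtype in (torch.int32, torch.int64):
            acc = make_arg((vocab, hidden))
            index = make_index((seq,), low=0, high=vocab, dtype=index_dtype)
            value = make_arg((seq, hidden))
            yield acc, index, value
```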

@jjsjann123 jjsjann123 mentioned this pull request Mar 14, 2025
2 tasks
@jjsjann123 jjsjann123 marked this pull request as ready for review March 14, 2025 01:18
@jjsjann123 jjsjann123 requested review from protonu and rdspring1 March 14, 2025 01:19
@jjsjann123 jjsjann123 marked this pull request as draft March 14, 2025 01:19
@jjsjann123 (Collaborator Author) commented:

Marking this as draft to avoid accidental merge, but this PR is good for review as-is.

@rdspring1 (Collaborator) left a comment:

Do you need to define `void handle(IndexAccumulateOp* iaop)` in csrc/python_frontend/translation.cpp for the python clone and segmentation features?

Otherwise, the PR looks good to me.

    py::arg("acc"),
    py::arg("index"),
    py::arg("value"),
    py::return_value_policy::reference);
Collaborator comment:

I'm trying to improve python user experience by adding a docstring to new functions.

Docstring generated by Gemini.

        m.def("index_accumulate", &indexAccumulate,
              py::arg("acc_tv"), py::arg("index_tv"), py::arg("value_tv"),
              R"(
            Accumulates values into a tensor at specified indices.
    
            This function performs a restricted version of `torch.index_put(..., accumulate=true)`.
            It adds the values from `value_tv` to the elements of `acc_tv` at the indices
            specified by `index_tv`.
    
            acc_tv: The tensor to accumulate into (in-place modification).
            index_tv: The tensor containing the indices.
            value_tv: The tensor containing the values to accumulate.
    
            Returns:
                A pointer to the modified `acc_tv` tensor.
    
            Note:
                This is a restricted version and may not support all features of the
                full `torch.index_put(..., accumulate=true)` function.
        )");

@jjsjann123 (Collaborator Author) replied:

Hahaha, thanks for the draft~~~ will add it in.

jjsjann123 added a commit that referenced this pull request Apr 18, 2025

This PR supports embedding backward, which requires
`torch.index_put_(..., accumulate=True)`.

Stacked PRs:
- [x] #4063  <-- This PR
- [ ] #4066

What this PR does:
* Added fusion IR node IndexPutAccumulateOp.
    1. IR signature ```TensorView* IndexPutAccumulateOp(
              TensorView* out,
              TensorView* acc,
              TensorView* index,
              TensorView* value)```
    2. Allow expression evaluation execution via `at::index_put`
* Added C++ API ```TensorView* indexPutAccumulate(TensorView* acc_tv,
TensorView* index_tv, TensorView* value_tv)```

There are two things worth noting:
1. The signature of the op requires an `acc` input, which is supposed to
be used by codegen later as an IO buffer. (We'll atomically add into the
buffer and output it directly as the same tensor.) Currently the
reference implementation doesn't do that yet and uses an out-of-place
version instead. We'll modify that once we have proper support for
outputs aliasing inputs.
2. We currently reject `IndexPutAccumulateOp` from schedulers; we'll
remove this restriction once we have proper codegen support.
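The out-of-place versus IO-buffer distinction in note 1 can be seen with plain PyTorch (a sketch; `acc` here stands in for the IO buffer):

```python
import torch

acc = torch.zeros(4)
index = torch.tensor([0, 0, 3])
value = torch.tensor([1.0, 2.0, 3.0])

# Out-of-place: what the current reference implementation does.
# acc is left untouched and a fresh tensor is returned.
out = torch.index_put(acc, (index,), value, accumulate=True)

# In-place IO-buffer behavior: what codegen intends to produce,
# with the output aliasing the acc input.
alias = acc.index_put_((index,), value, accumulate=True)
```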
Base automatically changed from jjsjann123/index_put to main April 18, 2025 16:53
@jjsjann123 jjsjann123 requested a review from rdspring1 April 21, 2025 17:22
@jjsjann123 jjsjann123 marked this pull request as ready for review April 21, 2025 17:22
@jjsjann123 (Collaborator Author) commented:

!test

Comment thread: csrc/logical_domain_map.cpp (Outdated)

    py::arg("index"),
    py::arg("value"),
    py::return_value_policy::reference,
    R"doc(
Collaborator comment:

👍🏼

Co-authored-by: Ryan Spring <rspring@nvidia.com>

@jjsjann123 (Collaborator Author) commented:

!test

@jjsjann123 (Collaborator Author) commented:

!test

@jjsjann123 jjsjann123 merged commit df447be into main Apr 23, 2025
41 of 43 checks passed
@jjsjann123 jjsjann123 deleted the jjsjann123/index_put_python_api branch April 23, 2025 22:53