indexAccumulate python api #4066
Review updated until commit 4896186
Marking this as draft to avoid accidental merge.
```cpp
py::arg("acc"),
py::arg("index"),
py::arg("value"),
py::return_value_policy::reference);
```
I'm trying to improve the Python user experience by adding a docstring to new functions.
Docstring generated by Gemini.
```cpp
m.def("index_accumulate", &indexAccumulate,
    py::arg("acc_tv"), py::arg("index_tv"), py::arg("value_tv"),
    R"(
Accumulates values into a tensor at specified indices.

This function performs a restricted version of `torch.index_put(..., accumulate=true)`.
It adds the values from `value_tv` to the elements of `acc_tv` at the indices
specified by `index_tv`.

Args:
    acc_tv: The tensor to accumulate into (in-place modification).
    index_tv: The tensor containing the indices.
    value_tv: The tensor containing the values to accumulate.

Returns:
    A pointer to the modified `acc_tv` tensor.

Note:
    This is a restricted version and may not support all features of the
    full `torch.index_put(..., accumulate=true)` function.
)");
```
Hahaha, thanks for the draft~~~ will add it in.
This PR supports embedding backward, which requires `torch.index_put_(..., accumulate=True)`.

Stacked PRs:
- [x] #4063 <-- This PR
- [ ] #4066

What this PR does:
* Added fusion IR node `IndexPutAccumulateOp`.
  1. IR signature:
     ```cpp
     TensorView* IndexPutAccumulateOp(
         TensorView* out,
         TensorView* acc,
         TensorView* index,
         TensorView* value)
     ```
  2. Allows expression evaluation execution via `at::index_put`.
* Added C++ API:
  ```cpp
  TensorView* indexPutAccumulate(
      TensorView* acc_tv,
      TensorView* index_tv,
      TensorView* value_tv)
  ```

There are two things worth noting:
1. The signature of the op requires an `acc` input, which is supposed to be used by codegen later as an IO buffer. (We'll atomically add into the buffer and output it directly as the same tensor.) Currently the reference implementation doesn't do that yet and uses an out-of-place version instead. We'll modify that once we have proper support for outputs aliasing outputs.
2. We currently reject `IndexPutAccumulateOp` from the schedulers; we'll remove that restriction once we have proper codegen support.
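For readers unfamiliar with the accumulate semantics the expression evaluator falls back to: values at repeated indices are summed into the target rather than overwritten. A minimal pure-Python sketch (hypothetical helper name, plain lists standing in for tensors):

```python
def index_put_accumulate(acc, index, value):
    """Add value[i] into acc[index[i]] for every i, in place.

    Repeated indices accumulate instead of overwriting, mirroring the
    behavior of torch.index_put_(..., accumulate=True) on 1-D tensors.
    """
    for i, idx in enumerate(index):
        acc[idx] += value[i]
    return acc

acc = [0.0, 0.0, 0.0]
# Index 1 appears twice, so both values sum into acc[1].
index_put_accumulate(acc, [1, 1, 2], [10.0, 20.0, 5.0])
# acc is now [0.0, 30.0, 5.0]
```

This duplicate-index case is exactly why a plain scatter (last write wins) is insufficient and why the codegen plan mentions atomic adds.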
!test
```cpp
py::arg("index"),
py::arg("value"),
py::return_value_policy::reference,
R"doc(
```
Co-authored-by: Ryan Spring <rspring@nvidia.com>
This PR supports embedding backward, which requires `torch.index_put_(..., accumulate=True)`.

Stacked PRs:

What this PR does:

```
Tensor fd.ops.index_accumulate(Tensor acc, Tensor index, Tensor value)
```
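For context on the motivating use case: in embedding backward, gradients for repeated token ids must be summed into the same weight row, which is exactly what accumulate-on-index provides. A hedged pure-Python sketch of that pattern (hypothetical helper name, nested lists standing in for tensors):

```python
def embedding_backward(grad_out, token_ids, num_embeddings, dim):
    """Accumulate per-token gradients into a [num_embeddings, dim] table.

    Each row of grad_out is added into the weight-gradient row selected
    by the corresponding token id; repeated ids sum their contributions.
    """
    grad_weight = [[0.0] * dim for _ in range(num_embeddings)]
    for row, tok in zip(grad_out, token_ids):
        for d in range(dim):
            grad_weight[tok][d] += row[d]
    return grad_weight

# Token 0 appears twice, so its two gradient rows are summed.
g = embedding_backward([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [0, 2, 0], 3, 2)
# g is [[6.0, 8.0], [0.0, 0.0], [3.0, 4.0]]
```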