-
Notifications
You must be signed in to change notification settings - Fork 318
[Feat] Add A Pass to Handle Negative Index #1192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,160 @@ | ||
| /*! | ||
| * \file legalize_negative_index.cc | ||
| * \brief Legalize negative indices in buffer load expressions. | ||
| */ | ||
|
|
||
| #include <tvm/ffi/reflection/registry.h> | ||
| #include <tvm/runtime/logging.h> | ||
| #include <tvm/tir/stmt_functor.h> | ||
| #include <tvm/tir/transform.h> | ||
|
|
||
| #include <unordered_map> | ||
| #include <vector> | ||
|
|
||
| #include "arith/ir_mutator_with_analyzer.h" | ||
| #include "arith/ir_visitor_with_analyzer.h" | ||
|
|
||
| namespace tvm { | ||
| namespace tl { | ||
|
|
||
| using namespace tir; | ||
| using arith::IRVisitorWithAnalyzer; | ||
|
|
||
| enum class IndexSignState { kNonNegative, kNegative, kUnknown }; | ||
|
|
||
| class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer { | ||
| public: | ||
| explicit NegativeIndexAnalyzer( | ||
| std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>> | ||
| *result) | ||
| : result_(result) {} | ||
|
|
||
| void VisitExpr_(const BufferLoadNode *op) final { | ||
| auto load = tvm::ffi::GetRef<BufferLoad>(op); | ||
| std::vector<IndexSignState> states; | ||
| states.reserve(op->indices.size()); | ||
| bool needs_record = false; | ||
|
|
||
| for (size_t i = 0; i < op->indices.size(); ++i) { | ||
| PrimExpr simplified = analyzer_.Simplify(op->indices[i]); | ||
| if (analyzer_.CanProve(simplified >= 0)) { | ||
| states.push_back(IndexSignState::kNonNegative); | ||
| continue; | ||
| } | ||
|
|
||
| if (analyzer_.CanProve(simplified < 0)) { | ||
| states.push_back(IndexSignState::kNegative); | ||
| needs_record = true; | ||
| continue; | ||
| } | ||
|
|
||
| states.push_back(IndexSignState::kUnknown); | ||
| needs_record = true; | ||
| LOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index " | ||
| << simplified << " for buffer " << load->buffer->name | ||
| << " (axis " << i << ")."; | ||
| } | ||
|
|
||
| if (needs_record) { | ||
| (*result_)[op] = std::move(states); | ||
| } | ||
|
|
||
| IRVisitorWithAnalyzer::VisitExpr_(op); | ||
| } | ||
|
|
||
| private: | ||
| std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>> | ||
| *result_; | ||
| }; | ||
|
|
||
| class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer { | ||
| public: | ||
| static PrimFunc | ||
| Apply(PrimFunc func, | ||
| const std::unordered_map<const BufferLoadNode *, | ||
| std::vector<IndexSignState>> &states) { | ||
| arith::Analyzer analyzer; | ||
| NegativeIndexRewriter rewriter(&analyzer, states); | ||
| if (!func->body.defined()) { | ||
| return func; | ||
| } | ||
| PrimFuncNode *func_node = func.CopyOnWrite(); | ||
| func_node->body = rewriter.VisitStmt(func_node->body); | ||
| return func; | ||
| } | ||
|
|
||
| private: | ||
| NegativeIndexRewriter( | ||
| arith::Analyzer *analyzer, | ||
| const std::unordered_map<const BufferLoadNode *, | ||
| std::vector<IndexSignState>> &states) | ||
| : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {} | ||
|
|
||
| PrimExpr VisitExpr_(const BufferLoadNode *op) final { | ||
| BufferLoad load = | ||
| Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op)); | ||
|
|
||
| auto it = states_.find(op); | ||
| if (it == states_.end()) { | ||
| return load; | ||
| } | ||
|
|
||
| auto indices = load->indices; | ||
| bool changed = false; | ||
|
|
||
| const auto &state_vector = it->second; | ||
| ICHECK_EQ(state_vector.size(), indices.size()) | ||
| << "State vector size mismatch for buffer load " << load->buffer->name; | ||
|
|
||
| for (size_t i = 0; i < indices.size(); ++i) { | ||
| if (state_vector[i] != IndexSignState::kNegative) { | ||
| continue; | ||
| } | ||
| PrimExpr extent = load->buffer->shape[i]; | ||
| indices.Set(i, analyzer_->Simplify(extent + indices[i])); | ||
| changed = true; | ||
| } | ||
|
|
||
| if (!changed) { | ||
| return load; | ||
| } | ||
|
|
||
| return BufferLoad(load->buffer, indices); | ||
| } | ||
|
|
||
| const std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>> | ||
| &states_; | ||
| }; | ||
|
|
||
| PrimFunc LegalizeNegativeIndex(PrimFunc func) { | ||
| if (!func->body.defined()) { | ||
| return func; | ||
| } | ||
|
|
||
| std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>> | ||
| states; | ||
| NegativeIndexAnalyzer analyzer(&states); | ||
| analyzer(func->body); | ||
| if (states.empty()) { | ||
| return func; | ||
| } | ||
|
|
||
| return NegativeIndexRewriter::Apply(std::move(func), states); | ||
| } | ||
|
|
||
| tvm::transform::Pass LegalizeNegativeIndexPass() { | ||
| using namespace tir::transform; | ||
| auto pass_func = [](PrimFunc f, const IRModule &, PassContext) { | ||
| return LegalizeNegativeIndex(std::move(f)); | ||
| }; | ||
| return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeNegativeIndex", {}); | ||
| } | ||
|
|
||
| TVM_FFI_STATIC_INIT_BLOCK() { | ||
| namespace refl = tvm::ffi::reflection; | ||
| refl::GlobalDef().def("tl.transform.LegalizeNegativeIndex", | ||
| LegalizeNegativeIndexPass); | ||
| } | ||
|
|
||
| } // namespace tl | ||
| } // namespace tvm | ||
60 changes: 60 additions & 0 deletions
60
testing/python/language/test_tilelang_language_negative_index.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| from tilelang import tvm | ||
| import tilelang as tl | ||
| import tilelang.testing | ||
| from tvm.script import tir as T | ||
|
|
||
|
|
||
| @T.prim_func | ||
| def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): | ||
| T.func_attr({"tir.noalias": True}) | ||
| B[0] = A[T.int32(-1)] | ||
|
|
||
|
|
||
| @T.prim_func | ||
| def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): | ||
| T.func_attr({"tir.noalias": True}) | ||
| B[0] = A[T.int32(15)] | ||
|
|
||
|
|
||
| @T.prim_func | ||
| def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")): | ||
| T.func_attr({"tir.noalias": True}) | ||
| for i in T.serial(4): | ||
| B[i] = A[-i - 1] | ||
|
|
||
|
|
||
| @T.prim_func | ||
| def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")): | ||
| T.func_attr({"tir.noalias": True}) | ||
| for i in T.serial(4): | ||
| B[i] = A[15 - i] | ||
|
|
||
|
|
||
| @T.prim_func | ||
| def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), | ||
| B: T.Buffer((16,), "float32")): | ||
| T.func_attr({"tir.noalias": True}) | ||
| for i in T.serial(16): | ||
| B[i] = A[shift + i] | ||
|
|
||
|
|
||
| def test_legalize_negative_index_scalar(): | ||
| mod = tvm.IRModule({"main": negative_index_before}) | ||
| transformed = tl.transform.LegalizeNegativeIndex()(mod) | ||
| tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_expected.body) | ||
|
|
||
|
|
||
| def test_legalize_negative_index_affine_expr(): | ||
| mod = tvm.IRModule({"main": negative_index_loop_before}) | ||
| transformed = tl.transform.LegalizeNegativeIndex()(mod) | ||
| tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_loop_expected.body) | ||
|
|
||
|
|
||
| def test_legalize_negative_index_symbolic_passthrough(): | ||
| mod = tvm.IRModule({"main": negative_index_symbolic_before}) | ||
| transformed = tl.transform.LegalizeNegativeIndex()(mod) | ||
| tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_symbolic_before.body) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tilelang.testing.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Negative index BufferStore remains unhandled
The analyzer/rewriter only records
BufferLoadNodeindices, so a PrimFunc that writes through a negative index (e.g.A[T.int32(-1)] = value) survives this pass unchanged. Because LowerAndLegalize now runs this pass unconditionally, we still exit with negative indices on the store side, defeating the stated goal of canonicalizing indices before downstream passes.Please extend the pass to cover
BufferStoreNodeas well (analyze its indices and rewrite them inVisitStmt_). A minimal shape for the fix:and mirror the same lookup/rewrite logic in
NegativeIndexRewriter::VisitStmt_(const BufferStoreNode*).This keeps the pipeline guarantee intact and prevents negative store indices from leaking past legalization.
🤖 Prompt for AI Agents