-
Notifications
You must be signed in to change notification settings - Fork 314
[Feat] Add A Pass to Handle Negative Index #1192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
WalkthroughA new transform pass is introduced to legalize negative indices in buffer operations. The implementation includes an analyzer to detect negative indices, a rewriter to convert them to non-negative equivalents, and integration into the transformation pipeline with corresponding test coverage. Changes
Sequence DiagramsequenceDiagram
participant PrimFunc as PrimFunc
participant Analyzer as NegativeIndexAnalyzer
participant Rewriter as NegativeIndexRewriter
participant Result as Transformed Function
PrimFunc->>Analyzer: Scan function body
Analyzer->>Analyzer: Classify each BufferLoad index<br/>(NonNegative/Negative/Unknown)
Analyzer-->>PrimFunc: Return state vectors per load
rect rgba(100, 150, 200, 0.2)
note over PrimFunc: If negative indices detected
PrimFunc->>Rewriter: Pass state vectors
Rewriter->>Rewriter: Rewrite BufferLoad nodes<br/>with index = extent + index
Rewriter-->>Result: Return new PrimFunc
end
rect rgba(150, 200, 100, 0.2)
note over PrimFunc: If no changes needed
PrimFunc-->>Result: Return original function
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes
Poem
Pre-merge checks and finishing touches✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_negative_index.py (1)
41-56: Add coverage for negative-index storesOnce the pass handles stores, please add a regression that writes through a negative index (e.g.
B[T.int32(-1)] = ...) and asserts the rewritten store usesextent + index. It will prevent future regressions around the BufferStore branch.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
src/transform/legalize_negative_index.cc(1 hunks)testing/python/language/test_tilelang_language_negative_index.py(1 hunks)tilelang/engine/phase.py(1 hunks)tilelang/transform/__init__.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
tilelang/transform/__init__.py (1)
src/transform/legalize_negative_index.cc (2)
LegalizeNegativeIndex(129-143)LegalizeNegativeIndex(129-129)
testing/python/language/test_tilelang_language_negative_index.py (3)
tilelang/language/ast/ir.py (1)
func_attr(217-225)src/transform/legalize_negative_index.cc (2)
LegalizeNegativeIndex(129-143)LegalizeNegativeIndex(129-129)tilelang/transform/__init__.py (1)
LegalizeNegativeIndex(83-91)
tilelang/engine/phase.py (2)
src/transform/legalize_negative_index.cc (2)
LegalizeNegativeIndex(129-143)LegalizeNegativeIndex(129-129)tilelang/transform/__init__.py (1)
LegalizeNegativeIndex(83-91)
src/transform/legalize_negative_index.cc (1)
tilelang/transform/__init__.py (1)
LegalizeNegativeIndex(83-91)
| class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer { | ||
| public: | ||
| explicit NegativeIndexAnalyzer( | ||
| std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>> | ||
| *result) | ||
| : result_(result) {} | ||
|
|
||
| void VisitExpr_(const BufferLoadNode *op) final { | ||
| auto load = tvm::ffi::GetRef<BufferLoad>(op); | ||
| std::vector<IndexSignState> states; | ||
| states.reserve(op->indices.size()); | ||
| bool needs_record = false; | ||
|
|
||
| for (size_t i = 0; i < op->indices.size(); ++i) { | ||
| PrimExpr simplified = analyzer_.Simplify(op->indices[i]); | ||
| if (analyzer_.CanProve(simplified >= 0)) { | ||
| states.push_back(IndexSignState::kNonNegative); | ||
| continue; | ||
| } | ||
|
|
||
| if (analyzer_.CanProve(simplified < 0)) { | ||
| states.push_back(IndexSignState::kNegative); | ||
| needs_record = true; | ||
| continue; | ||
| } | ||
|
|
||
| states.push_back(IndexSignState::kUnknown); | ||
| needs_record = true; | ||
| LOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index " | ||
| << simplified << " for buffer " << load->buffer->name | ||
| << " (axis " << i << ")."; | ||
| } | ||
|
|
||
| if (needs_record) { | ||
| (*result_)[op] = std::move(states); | ||
| } | ||
|
|
||
| IRVisitorWithAnalyzer::VisitExpr_(op); | ||
| } | ||
|
|
||
| private: | ||
| std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>> | ||
| *result_; | ||
| }; | ||
|
|
||
| class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer { | ||
| public: | ||
| static PrimFunc | ||
| Apply(PrimFunc func, | ||
| const std::unordered_map<const BufferLoadNode *, | ||
| std::vector<IndexSignState>> &states) { | ||
| arith::Analyzer analyzer; | ||
| NegativeIndexRewriter rewriter(&analyzer, states); | ||
| if (!func->body.defined()) { | ||
| return func; | ||
| } | ||
| PrimFuncNode *func_node = func.CopyOnWrite(); | ||
| func_node->body = rewriter.VisitStmt(func_node->body); | ||
| return func; | ||
| } | ||
|
|
||
| private: | ||
| NegativeIndexRewriter( | ||
| arith::Analyzer *analyzer, | ||
| const std::unordered_map<const BufferLoadNode *, | ||
| std::vector<IndexSignState>> &states) | ||
| : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {} | ||
|
|
||
| PrimExpr VisitExpr_(const BufferLoadNode *op) final { | ||
| BufferLoad load = | ||
| Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op)); | ||
|
|
||
| auto it = states_.find(op); | ||
| if (it == states_.end()) { | ||
| return load; | ||
| } | ||
|
|
||
| auto indices = load->indices; | ||
| bool changed = false; | ||
|
|
||
| const auto &state_vector = it->second; | ||
| ICHECK_EQ(state_vector.size(), indices.size()) | ||
| << "State vector size mismatch for buffer load " << load->buffer->name; | ||
|
|
||
| for (size_t i = 0; i < indices.size(); ++i) { | ||
| if (state_vector[i] != IndexSignState::kNegative) { | ||
| continue; | ||
| } | ||
| PrimExpr extent = load->buffer->shape[i]; | ||
| indices.Set(i, analyzer_->Simplify(extent + indices[i])); | ||
| changed = true; | ||
| } | ||
|
|
||
| if (!changed) { | ||
| return load; | ||
| } | ||
|
|
||
| return BufferLoad(load->buffer, indices); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Negative index BufferStore remains unhandled
The analyzer/rewriter only records BufferLoadNode indices, so a PrimFunc that writes through a negative index (e.g. A[T.int32(-1)] = value) survives this pass unchanged. Because LowerAndLegalize now runs this pass unconditionally, we still exit with negative indices on the store side, defeating the stated goal of canonicalizing indices before downstream passes.
Please extend the pass to cover BufferStoreNode as well (analyze its indices and rewrite them in VisitStmt_). A minimal shape for the fix:
- void VisitExpr_(const BufferLoadNode* op) final {
- ...
- (*result_)[op] = std::move(states);
- }
- IRVisitorWithAnalyzer::VisitExpr_(op);
- }
+ template <typename TNode>
+ void RecordIndices(const TNode* op, const Array<PrimExpr>& indices, const Buffer& buffer) {
+ ...
+ (*result_)[op] = std::move(states);
+ }
+ }
+
+ void VisitExpr_(const BufferLoadNode* op) final {
+ RecordIndices(op, op->indices, op->buffer);
+ IRVisitorWithAnalyzer::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const BufferStoreNode* op) final {
+ RecordIndices(op, op->indices, op->buffer);
+ IRVisitorWithAnalyzer::VisitStmt_(op);
+ }and mirror the same lookup/rewrite logic in NegativeIndexRewriter::VisitStmt_(const BufferStoreNode*).
This keeps the pipeline guarantee intact and prevents negative store indices from leaking past legalization.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/transform/legalize_negative_index.cc around lines 25-123, the pass only
analyzes/re-writes BufferLoadNode indices so negative indices in BufferStoreNode
survive; extend the pass to handle stores: add an analyzer VisitStmt_(const
BufferStoreNode*) that mirrors the BufferLoad logic (simplify each index,
classify as kNonNegative/kNegative/kUnknown, set needs_record and log warnings,
and record the per-index IndexSignState into a store-specific map), and add a
corresponding NegativeIndexRewriter::VisitStmt_(const BufferStoreNode*) that
looks up the recorded state vector for that store, asserts sizes match, rewrites
negative indices to extent + index (using analyzer_->Simplify) and returns a
modified BufferStore stmt when changed; update Apply and the pass state to
carry/store both load and store maps (or add a separate store map) so both
analysis and rewriting use the store information.
|
@codex review |
|
Codex Review: Didn't find any major issues. Delightful! ℹ️ About Codex in GitHubYour team has set up Codex to review pull requests in this repo. Reviews are triggered when you
If Codex has suggestions, it will comment; otherwise it will react with 👍. Codex can also answer questions or update the PR. Try commenting "@codex address that feedback". |
This pr add a
LegalizeNegativeIndexpass to handle negative index in BufferLoad operation.It tries to handle each dimension:
shape[i] + idx[i]Summary by CodeRabbit
Release Notes
New Features
Tests