160 changes: 160 additions & 0 deletions src/transform/legalize_negative_index.cc
@@ -0,0 +1,160 @@
/*!
* \file legalize_negative_index.cc
* \brief Legalize negative indices in buffer load expressions.
*/

#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <vector>

#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRVisitorWithAnalyzer;

enum class IndexSignState { kNonNegative, kNegative, kUnknown };

class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
public:
explicit NegativeIndexAnalyzer(
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
*result)
: result_(result) {}

void VisitExpr_(const BufferLoadNode *op) final {
auto load = tvm::ffi::GetRef<BufferLoad>(op);
std::vector<IndexSignState> states;
states.reserve(op->indices.size());
bool needs_record = false;

for (size_t i = 0; i < op->indices.size(); ++i) {
PrimExpr simplified = analyzer_.Simplify(op->indices[i]);
if (analyzer_.CanProve(simplified >= 0)) {
states.push_back(IndexSignState::kNonNegative);
continue;
}

if (analyzer_.CanProve(simplified < 0)) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
}

states.push_back(IndexSignState::kUnknown);
needs_record = true;
LOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name
<< " (axis " << i << ").";
}

if (needs_record) {
(*result_)[op] = std::move(states);
}

IRVisitorWithAnalyzer::VisitExpr_(op);
}

private:
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
*result_;
};

class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer {
public:
static PrimFunc
Apply(PrimFunc func,
const std::unordered_map<const BufferLoadNode *,
std::vector<IndexSignState>> &states) {
arith::Analyzer analyzer;
NegativeIndexRewriter rewriter(&analyzer, states);
if (!func->body.defined()) {
return func;
}
PrimFuncNode *func_node = func.CopyOnWrite();
func_node->body = rewriter.VisitStmt(func_node->body);
return func;
}

private:
NegativeIndexRewriter(
arith::Analyzer *analyzer,
const std::unordered_map<const BufferLoadNode *,
std::vector<IndexSignState>> &states)
: arith::IRMutatorWithAnalyzer(analyzer), states_(states) {}

PrimExpr VisitExpr_(const BufferLoadNode *op) final {
BufferLoad load =
Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));

auto it = states_.find(op);
if (it == states_.end()) {
return load;
}

auto indices = load->indices;
bool changed = false;

const auto &state_vector = it->second;
ICHECK_EQ(state_vector.size(), indices.size())
<< "State vector size mismatch for buffer load " << load->buffer->name;

for (size_t i = 0; i < indices.size(); ++i) {
if (state_vector[i] != IndexSignState::kNegative) {
continue;
}
PrimExpr extent = load->buffer->shape[i];
indices.Set(i, analyzer_->Simplify(extent + indices[i]));
changed = true;
}

if (!changed) {
return load;
}

return BufferLoad(load->buffer, indices);
}
Comment on lines +25 to +123
Contributor
⚠️ Potential issue | 🟠 Major
Negative index BufferStore remains unhandled

The analyzer/rewriter only records BufferLoadNode indices, so a PrimFunc that writes through a negative index (e.g. A[T.int32(-1)] = value) survives this pass unchanged. Because LowerAndLegalize now runs this pass unconditionally, we still exit with negative indices on the store side, defeating the stated goal of canonicalizing indices before downstream passes.

Please extend the pass to cover BufferStoreNode as well (analyze its indices and rewrite them in VisitStmt_). A minimal shape for the fix:

-  void VisitExpr_(const BufferLoadNode* op) final {
-    ...
-      (*result_)[op] = std::move(states);
-    }
-    IRVisitorWithAnalyzer::VisitExpr_(op);
-  }
+  template <typename TNode>
+  void RecordIndices(const TNode* op, const Array<PrimExpr>& indices, const Buffer& buffer) {
+    ...
+      (*result_)[op] = std::move(states);
+    }
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    RecordIndices(op, op->indices, op->buffer);
+    IRVisitorWithAnalyzer::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    RecordIndices(op, op->indices, op->buffer);
+    IRVisitorWithAnalyzer::VisitStmt_(op);
+  }

and mirror the same lookup/rewrite logic in NegativeIndexRewriter::VisitStmt_(const BufferStoreNode*).

This keeps the pipeline guarantee intact and prevents negative store indices from leaking past legalization.

Committable suggestion skipped: line range outside the PR's diff.
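
Since the committable suggestion was skipped, here is a minimal sketch of the mirrored rewriter hook the comment describes. It is hedged, not code from this PR: it assumes the rewriter also receives a map keyed by BufferStoreNode* (called store_states_ here, a hypothetical member analogous to states_) that the analyzer fills in its own VisitStmt_(const BufferStoreNode*).

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    // Let the base mutator rewrite the value and indices first.
    BufferStore store =
        Downcast<BufferStore>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));

    auto it = store_states_.find(op);  // hypothetical map, mirrors states_
    if (it == store_states_.end()) {
      return store;
    }

    auto indices = store->indices;
    const auto &state_vector = it->second;
    ICHECK_EQ(state_vector.size(), indices.size())
        << "State vector size mismatch for buffer store "
        << store->buffer->name;

    bool changed = false;
    for (size_t i = 0; i < indices.size(); ++i) {
      if (state_vector[i] != IndexSignState::kNegative) {
        continue;
      }
      // Same wrap-around as the load path: extent + negative index.
      indices.Set(i, analyzer_->Simplify(store->buffer->shape[i] + indices[i]));
      changed = true;
    }

    if (!changed) {
      return store;
    }
    return BufferStore(store->buffer, store->value, indices);
  }

Keeping the wrap-around identical to the existing BufferLoad path means loads and stores are canonicalized the same way before downstream passes run.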

🤖 Prompt for AI Agents
In src/transform/legalize_negative_index.cc around lines 25-123, the pass only
analyzes/re-writes BufferLoadNode indices so negative indices in BufferStoreNode
survive; extend the pass to handle stores: add an analyzer VisitStmt_(const
BufferStoreNode*) that mirrors the BufferLoad logic (simplify each index,
classify as kNonNegative/kNegative/kUnknown, set needs_record and log warnings,
and record the per-index IndexSignState into a store-specific map), and add a
corresponding NegativeIndexRewriter::VisitStmt_(const BufferStoreNode*) that
looks up the recorded state vector for that store, asserts sizes match, rewrites
negative indices to extent + index (using analyzer_->Simplify) and returns a
modified BufferStore stmt when changed; update Apply and the pass state to
carry/store both load and store maps (or add a separate store map) so both
analysis and rewriting use the store information.


const std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
&states_;
};

PrimFunc LegalizeNegativeIndex(PrimFunc func) {
if (!func->body.defined()) {
return func;
}

std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
states;
NegativeIndexAnalyzer analyzer(&states);
analyzer(func->body);
if (states.empty()) {
return func;
}

return NegativeIndexRewriter::Apply(std::move(func), states);
}

tvm::transform::Pass LegalizeNegativeIndexPass() {
using namespace tir::transform;
auto pass_func = [](PrimFunc f, const IRModule &, PassContext) {
return LegalizeNegativeIndex(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeNegativeIndex", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeNegativeIndex",
LegalizeNegativeIndexPass);
}

} // namespace tl
} // namespace tvm
60 changes: 60 additions & 0 deletions testing/python/language/test_tilelang_language_negative_index.py
@@ -0,0 +1,60 @@
from tilelang import tvm
import tilelang as tl
import tilelang.testing
from tvm.script import tir as T


@T.prim_func
def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
T.func_attr({"tir.noalias": True})
B[0] = A[T.int32(-1)]


@T.prim_func
def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
T.func_attr({"tir.noalias": True})
B[0] = A[T.int32(15)]


@T.prim_func
def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
T.func_attr({"tir.noalias": True})
for i in T.serial(4):
B[i] = A[-i - 1]


@T.prim_func
def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
T.func_attr({"tir.noalias": True})
for i in T.serial(4):
B[i] = A[15 - i]


@T.prim_func
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"),
B: T.Buffer((16,), "float32")):
T.func_attr({"tir.noalias": True})
for i in T.serial(16):
B[i] = A[shift + i]


def test_legalize_negative_index_scalar():
mod = tvm.IRModule({"main": negative_index_before})
transformed = tl.transform.LegalizeNegativeIndex()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_expected.body)


def test_legalize_negative_index_affine_expr():
mod = tvm.IRModule({"main": negative_index_loop_before})
transformed = tl.transform.LegalizeNegativeIndex()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_loop_expected.body)


def test_legalize_negative_index_symbolic_passthrough():
mod = tvm.IRModule({"main": negative_index_symbolic_before})
transformed = tl.transform.LegalizeNegativeIndex()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_symbolic_before.body)


if __name__ == "__main__":
tilelang.testing.main()
2 changes: 2 additions & 0 deletions tilelang/engine/phase.py
@@ -96,6 +96,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LetInline()(mod)
# Add wrapper for single buf store
mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
# Normalize negative indices to canonical non-negative form
mod = tilelang.transform.LegalizeNegativeIndex()(mod)
# Inject assumes to speedup tvm prover
mod = tilelang.transform.InjectAssumes()(mod)
# Simplify the IR expressions
11 changes: 11 additions & 0 deletions tilelang/transform/__init__.py
@@ -80,6 +80,17 @@ def FrontendLegalize():
return _ffi_api.FrontendLegalize() # type: ignore


def LegalizeNegativeIndex():
"""Legalize negative indices in buffer loads.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LegalizeNegativeIndex() # type: ignore


def InjectAssumes():
"""Inject Assumes
