Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a consistent idiom for visit_let #8540

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/BoundSmallAllocations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,40 +17,40 @@ class BoundSmallAllocations : public IRMutator {
// Track constant bounds
Scope<Interval> scope;

template<typename T, typename Body>
Body visit_let(const T *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
// Visit an entire chain of lets in a single method to conserve stack space.
struct Frame {
const T *op;
const LetOrLetStmt *op;
ScopedBinding<Interval> binding;
Frame(const T *op, Scope<Interval> &scope)
Frame(const LetOrLetStmt *op, Scope<Interval> &scope)
: op(op),
binding(scope, op->name, find_constant_bounds(op->value, scope)) {
}
};
std::vector<Frame> frames;
Body result;
decltype(op->body) result;

do {
result = op->body;
frames.emplace_back(op, scope);
} while ((op = result.template as<T>()));
} while ((op = result.template as<LetOrLetStmt>()));

result = mutate(result);

for (const auto &frame : reverse_view(frames)) {
result = T::make(frame.op->name, frame.op->value, result);
result = LetOrLetStmt::make(frame.op->name, frame.op->value, result);
}

return result;
}

Stmt visit(const LetStmt *op) override {
return visit_let<LetStmt, Stmt>(op);
return visit_let(op);
}

Expr visit(const Let *op) override {
return visit_let<Let, Expr>(op);
return visit_let(op);
}

bool in_thread_loop = false;
Expand Down
18 changes: 9 additions & 9 deletions src/ClampUnsafeAccesses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ struct ClampUnsafeAccesses : IRMutator {
}

Expr visit(const Let *let) override {
return visit_let<Let, Expr>(let);
return visit_let(let);
}

Stmt visit(const LetStmt *let) override {
return visit_let<LetStmt, Stmt>(let);
return visit_let(let);
}

Expr visit(const Variable *var) override {
Expand Down Expand Up @@ -80,15 +80,15 @@ struct ClampUnsafeAccesses : IRMutator {
}

private:
template<typename L, typename Body>
Body visit_let(const L *let) {
ScopedBinding<bool> binding(let_var_inside_indexing, let->name, false);
Body body = mutate(let->body);
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
ScopedBinding<bool> binding(let_var_inside_indexing, op->name, false);
auto body = mutate(op->body);

ScopedValue<bool> s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(let->name));
Expr value = mutate(let->value);
ScopedValue<bool> s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(op->name));
Expr value = mutate(op->value);

return L::make(let->name, std::move(value), std::move(body));
return LetOrLetStmt::make(op->name, std::move(value), std::move(body));
}

bool bounds_smaller_than_type(const Interval &bounds, Type type) {
Expand Down
28 changes: 14 additions & 14 deletions src/Deinterleave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,44 +465,44 @@ class Interleaver : public IRMutator {
return Shuffle::make_interleave(exprs);
}

template<typename T, typename Body>
Body visit_lets(const T *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
// Visit an entire chain of lets in a single method to conserve stack space.
struct Frame {
const T *op;
const LetOrLetStmt *op;
Expr new_value;
ScopedBinding<> binding;
Frame(const T *op, Expr v, Scope<void> &scope)
Frame(const LetOrLetStmt *op, Expr v, Scope<void> &scope)
: op(op),
new_value(std::move(v)),
binding(new_value.type().is_vector(), scope, op->name) {
}
};
std::vector<Frame> frames;
Body result;
decltype(op->body) result;

do {
result = op->body;
frames.emplace_back(op, mutate(op->value), vector_lets);
} while ((op = result.template as<T>()));
} while ((op = result.template as<LetOrLetStmt>()));

result = mutate(result);

for (const auto &frame : reverse_view(frames)) {
Expr value = std::move(frame.new_value);

result = T::make(frame.op->name, value, result);
result = LetOrLetStmt::make(frame.op->name, value, result);

// For vector lets, we may additionally need a let defining the even and odd lanes only
if (value.type().is_vector()) {
if (value.type().lanes() % 2 == 0) {
result = T::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result);
result = T::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result);
result = LetOrLetStmt::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result);
result = LetOrLetStmt::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result);
}
if (value.type().lanes() % 3 == 0) {
result = T::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result);
result = T::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result);
result = T::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result);
result = LetOrLetStmt::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result);
result = LetOrLetStmt::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result);
result = LetOrLetStmt::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result);
}
}
}
Expand All @@ -511,11 +511,11 @@ class Interleaver : public IRMutator {
}

Expr visit(const Let *op) override {
return visit_lets<Let, Expr>(op);
return visit_let(op);
}

Stmt visit(const LetStmt *op) override {
return visit_lets<LetStmt, Stmt>(op);
return visit_let(op);
}

Expr visit(const Ramp *op) override {
Expand Down
10 changes: 5 additions & 5 deletions src/EliminateBoolVectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ class EliminateBoolVectors : public IRMutator {
return expr;
}

template<typename NodeType, typename LetType>
NodeType visit_let(const LetType *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
Expr value = mutate(op->value);

// We changed the type of the let, we need to replace the
Expand All @@ -305,17 +305,17 @@ class EliminateBoolVectors : public IRMutator {
}

if (!value.same_as(op->value) || !body.same_as(op->body)) {
return LetType::make(op->name, value, body);
return LetOrLetStmt::make(op->name, value, body);
} else {
return op;
}
}

Expr visit(const Let *op) override {
return visit_let<Expr>(op);
return visit_let(op);
}
Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
return visit_let(op);
}
};

Expand Down
10 changes: 5 additions & 5 deletions src/FuseGPUThreadLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1157,9 +1157,9 @@ class ExtractRegisterAllocations : public IRMutator {
op->param, mutate(op->predicate), op->alignment);
}

template<typename ExprOrStmt, typename LetOrLetStmt>
ExprOrStmt visit_let(const LetOrLetStmt *op) {
ExprOrStmt body = op->body;
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
auto body = op->body;

body = mutate(op->body);
Expr value = mutate(op->value);
Expand All @@ -1178,11 +1178,11 @@ class ExtractRegisterAllocations : public IRMutator {
}

Expr visit(const Let *op) override {
return visit_let<Expr>(op);
return visit_let(op);
}

Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
return visit_let(op);
}

Scope<int> register_allocations;
Expand Down
42 changes: 21 additions & 21 deletions src/HexagonOptimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1088,20 +1088,20 @@ class OptimizePatterns : public IRMutator {
}
}

template<typename NodeType, typename T>
NodeType visit_let(const T *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds));
NodeType node = IRMutator::visit(op);
auto node = IRMutator::visit(op);
bounds.pop(op->name);
return node;
}

Expr visit(const Let *op) override {
return visit_let<Expr>(op);
return visit_let(op);
}

Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
return visit_let(op);
}

Expr visit(const Div *op) override {
Expand Down Expand Up @@ -1599,12 +1599,12 @@ class EliminateInterleaves : public IRMutator {
}
}

template<typename NodeType, typename LetType>
NodeType visit_let(const LetType *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {

Expr value = mutate(op->value);
string deinterleaved_name;
NodeType body;
decltype(op->body) body;
// Other code in this mutator needs to be able to tell the
// difference between a Let that yields a deinterleave, and a
// let that has a removable deinterleave. Lets that can
Expand Down Expand Up @@ -1632,10 +1632,10 @@ class EliminateInterleaves : public IRMutator {
return op;
} else if (body.same_as(op->body)) {
// If the body didn't change, we must not have used the deinterleaved value.
return LetType::make(op->name, value, body);
return LetOrLetStmt::make(op->name, value, body);
} else {
// We need to rewrap the body with new lets.
NodeType result = body;
auto result = body;
bool deinterleaved_used = stmt_or_expr_uses_var(result, deinterleaved_name);
bool interleaved_used = stmt_or_expr_uses_var(result, op->name);
if (deinterleaved_used && interleaved_used) {
Expand All @@ -1653,14 +1653,14 @@ class EliminateInterleaves : public IRMutator {
interleaved = native_interleave(interleaved);
}

result = LetType::make(op->name, interleaved, result);
return LetType::make(deinterleaved_name, deinterleaved, result);
result = LetOrLetStmt::make(op->name, interleaved, result);
return LetOrLetStmt::make(deinterleaved_name, deinterleaved, result);
} else if (deinterleaved_used) {
// Only the deinterleaved value is used, we can eliminate the interleave.
return LetType::make(deinterleaved_name, remove_interleave(value), result);
return LetOrLetStmt::make(deinterleaved_name, remove_interleave(value), result);
} else if (interleaved_used) {
// Only the original value is used, regenerate the let.
return LetType::make(op->name, value, result);
return LetOrLetStmt::make(op->name, value, result);
} else {
// The let must have been dead.
internal_assert(!stmt_or_expr_uses_var(op->body, op->name))
Expand All @@ -1671,7 +1671,7 @@ class EliminateInterleaves : public IRMutator {
}

Expr visit(const Let *op) override {
Expr expr = visit_let<Expr>(op);
Expr expr = visit_let(op);

// Lift interleaves out of Let expression bodies.
const Let *let = expr.as<Let>();
Expand All @@ -1682,7 +1682,7 @@ class EliminateInterleaves : public IRMutator {
}

Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
return visit_let(op);
}

Expr visit(const Cast *op) override {
Expand Down Expand Up @@ -2047,25 +2047,25 @@ class ScatterGatherGenerator : public IRMutator {
return IRMutator::visit(op);
}

template<typename NodeType, typename T>
NodeType visit_let(const T *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
// We only care about vector lets.
if (op->value.type().is_vector()) {
bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds));
}
NodeType node = IRMutator::visit(op);
auto node = IRMutator::visit(op);
if (op->value.type().is_vector()) {
bounds.pop(op->name);
}
return node;
}

Expr visit(const Let *op) override {
return visit_let<Expr>(op);
return visit_let(op);
}

Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
return visit_let(op);
}

Stmt visit(const Allocate *op) override {
Expand Down
Loading