From 0cd84a15e76c90ad734efc6f1e953b1e54a7539f Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 23 Dec 2024 14:34:26 -0800 Subject: [PATCH] Use a consistent idiom for visit_let visit_let in the codebase uses a wide variety of template names, argument names, and ways of getting the body type. This just picks one and uses it consistently. No functional changes. --- src/BoundSmallAllocations.cpp | 18 ++++++------ src/ClampUnsafeAccesses.cpp | 18 ++++++------ src/Deinterleave.cpp | 28 +++++++++--------- src/EliminateBoolVectors.cpp | 10 +++---- src/FuseGPUThreadLoops.cpp | 10 +++---- src/HexagonOptimize.cpp | 42 +++++++++++++-------------- src/LICM.cpp | 36 +++++++++++------------ src/LowerWarpShuffles.cpp | 10 +++---- src/OptimizeShuffles.cpp | 12 ++++---- src/RemoveUndef.cpp | 18 ++++++------ src/SimplifyCorrelatedDifferences.cpp | 10 +++---- src/TrimNoOps.cpp | 10 +++---- src/UniquifyVariableNames.cpp | 6 ++-- src/VectorizeLoops.cpp | 12 ++++---- 14 files changed, 120 insertions(+), 120 deletions(-) diff --git a/src/BoundSmallAllocations.cpp b/src/BoundSmallAllocations.cpp index 849d224051b4..80f58889448c 100644 --- a/src/BoundSmallAllocations.cpp +++ b/src/BoundSmallAllocations.cpp @@ -17,40 +17,40 @@ class BoundSmallAllocations : public IRMutator { // Track constant bounds Scope scope; - template - Body visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // Visit an entire chain of lets in a single method to conserve stack space. struct Frame { - const T *op; + const LetOrLetStmt *op; ScopedBinding binding; - Frame(const T *op, Scope &scope) + Frame(const LetOrLetStmt *op, Scope &scope) : op(op), binding(scope, op->name, find_constant_bounds(op->value, scope)) { } }; std::vector frames; - Body result; + decltype(op->body) result; do { result = op->body; frames.emplace_back(op, scope); - } while ((op = result.template as())); + } while ((op = result.template as())); result = mutate(result); for (const auto &frame : reverse_view(frames)) { - result = T::make(frame.op->name, frame.op->value, result); + result = LetOrLetStmt::make(frame.op->name, frame.op->value, result); } return result; } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } bool in_thread_loop = false; diff --git a/src/ClampUnsafeAccesses.cpp b/src/ClampUnsafeAccesses.cpp index b3dd9ddc235e..ed6955446196 100644 --- a/src/ClampUnsafeAccesses.cpp +++ b/src/ClampUnsafeAccesses.cpp @@ -42,11 +42,11 @@ struct ClampUnsafeAccesses : IRMutator { } Expr visit(const Let *let) override { - return visit_let(let); + return visit_let(let); } Stmt visit(const LetStmt *let) override { - return visit_let(let); + return visit_let(let); } Expr visit(const Variable *var) override { @@ -80,15 +80,15 @@ struct ClampUnsafeAccesses : IRMutator { } private: - template - Body visit_let(const L *let) { - ScopedBinding binding(let_var_inside_indexing, let->name, false); - Body body = mutate(let->body); + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { + ScopedBinding binding(let_var_inside_indexing, op->name, false); + auto body = mutate(op->body); - ScopedValue s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(let->name)); - Expr value = mutate(let->value); + ScopedValue s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(op->name)); + Expr value = mutate(op->value); - return L::make(let->name, std::move(value), std::move(body)); + return LetOrLetStmt::make(op->name, std::move(value), std::move(body)); } bool bounds_smaller_than_type(const Interval &bounds, Type type) { diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 54199485fed5..6cd4f490f01c 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -465,44 +465,44 @@ class Interleaver : public IRMutator { return Shuffle::make_interleave(exprs); } - template - Body visit_lets(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // Visit an entire chain of lets in a single method to conserve stack space. struct Frame { - const T *op; + const LetOrLetStmt *op; Expr new_value; ScopedBinding<> binding; - Frame(const T *op, Expr v, Scope &scope) + Frame(const LetOrLetStmt *op, Expr v, Scope &scope) : op(op), new_value(std::move(v)), binding(new_value.type().is_vector(), scope, op->name) { } }; std::vector frames; - Body result; + decltype(op->body) result; do { result = op->body; frames.emplace_back(op, mutate(op->value), vector_lets); - } while ((op = result.template as())); + } while ((op = result.template as())); result = mutate(result); for (const auto &frame : reverse_view(frames)) { Expr value = std::move(frame.new_value); - result = T::make(frame.op->name, value, result); + result = LetOrLetStmt::make(frame.op->name, value, result); // For vector lets, we may additionally need a let defining the even and odd lanes only if (value.type().is_vector()) { if (value.type().lanes() % 2 == 0) { - result = T::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result); - result = T::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result); + result = LetOrLetStmt::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result); + result = LetOrLetStmt::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result); } if (value.type().lanes() % 3 == 0) { - result = T::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result); - result = T::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result); - result = T::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result); + result = LetOrLetStmt::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result); + result = LetOrLetStmt::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result); + result = LetOrLetStmt::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result); } } } @@ -511,11 +511,11 @@ class Interleaver : public IRMutator { } Expr visit(const Let *op) override { - return visit_lets(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_lets(op); + return visit_let(op); } Expr visit(const Ramp *op) override { diff --git a/src/EliminateBoolVectors.cpp b/src/EliminateBoolVectors.cpp index 2517ecf70c1c..e4afa8f21569 100644 --- a/src/EliminateBoolVectors.cpp +++ b/src/EliminateBoolVectors.cpp @@ -287,8 +287,8 @@ class EliminateBoolVectors : public IRMutator { return expr; } - template - NodeType visit_let(const LetType *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { Expr value = mutate(op->value); // We changed the type of the let, we need to replace the @@ -305,17 +305,17 @@ class EliminateBoolVectors : public IRMutator { } if (!value.same_as(op->value) || !body.same_as(op->body)) { - return LetType::make(op->name, value, body); + return LetOrLetStmt::make(op->name, value, body); } else { return op; } } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } }; diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index 3464a91ddb58..b7da1555526a 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -1157,9 +1157,9 @@ class ExtractRegisterAllocations : public IRMutator { op->param, mutate(op->predicate), op->alignment); } - template - ExprOrStmt visit_let(const LetOrLetStmt *op) { - ExprOrStmt body = op->body; + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { + auto body = op->body; body = mutate(op->body); Expr value = mutate(op->value); @@ -1178,11 +1178,11 @@ class ExtractRegisterAllocations : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Scope register_allocations; diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index d75231813215..a88f0e6bad57 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -1088,20 +1088,20 @@ class OptimizePatterns : public IRMutator { } } - template - NodeType visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds)); - NodeType node = IRMutator::visit(op); + auto node = IRMutator::visit(op); bounds.pop(op->name); return node; } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Expr visit(const Div *op) override { @@ -1599,12 +1599,12 @@ class EliminateInterleaves : public IRMutator { } } - template - NodeType visit_let(const LetType *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { Expr value = mutate(op->value); string deinterleaved_name; - NodeType body; + decltype(op->body) body; // Other code in this mutator needs to be able to tell the // difference between a Let that yields a deinterleave, and a // let that has a removable deinterleave. Lets that can @@ -1632,10 +1632,10 @@ class EliminateInterleaves : public IRMutator { return op; } else if (body.same_as(op->body)) { // If the body didn't change, we must not have used the deinterleaved value. - return LetType::make(op->name, value, body); + return LetOrLetStmt::make(op->name, value, body); } else { // We need to rewrap the body with new lets. - NodeType result = body; + auto result = body; bool deinterleaved_used = stmt_or_expr_uses_var(result, deinterleaved_name); bool interleaved_used = stmt_or_expr_uses_var(result, op->name); if (deinterleaved_used && interleaved_used) { @@ -1653,14 +1653,14 @@ class EliminateInterleaves : public IRMutator { interleaved = native_interleave(interleaved); } - result = LetType::make(op->name, interleaved, result); - return LetType::make(deinterleaved_name, deinterleaved, result); + result = LetOrLetStmt::make(op->name, interleaved, result); + return LetOrLetStmt::make(deinterleaved_name, deinterleaved, result); } else if (deinterleaved_used) { // Only the deinterleaved value is used, we can eliminate the interleave. - return LetType::make(deinterleaved_name, remove_interleave(value), result); + return LetOrLetStmt::make(deinterleaved_name, remove_interleave(value), result); } else if (interleaved_used) { // Only the original value is used, regenerate the let. - return LetType::make(op->name, value, result); + return LetOrLetStmt::make(op->name, value, result); } else { // The let must have been dead. internal_assert(!stmt_or_expr_uses_var(op->body, op->name)) @@ -1671,7 +1671,7 @@ class EliminateInterleaves : public IRMutator { } Expr visit(const Let *op) override { - Expr expr = visit_let(op); + Expr expr = visit_let(op); // Lift interleaves out of Let expression bodies. const Let *let = expr.as(); @@ -1682,7 +1682,7 @@ class EliminateInterleaves : public IRMutator { } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Expr visit(const Cast *op) override { @@ -2047,13 +2047,13 @@ class ScatterGatherGenerator : public IRMutator { return IRMutator::visit(op); } - template - NodeType visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // We only care about vector lets. if (op->value.type().is_vector()) { bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds)); } - NodeType node = IRMutator::visit(op); + auto node = IRMutator::visit(op); if (op->value.type().is_vector()) { bounds.pop(op->name); } @@ -2061,11 +2061,11 @@ class ScatterGatherGenerator : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const Allocate *op) override { diff --git a/src/LICM.cpp b/src/LICM.cpp index 880354d89582..c73fcf5424ab 100644 --- a/src/LICM.cpp +++ b/src/LICM.cpp @@ -112,23 +112,23 @@ class LiftLoopInvariants : public IRMutator { return true; } - template - Body visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // Visit an entire chain of lets in a single method to conserve stack space. struct Frame { - const T *op; + const LetOrLetStmt *op; Expr new_value; ScopedBinding<> binding; - Frame(const T *op, Expr v, Scope<> &scope) + Frame(const LetOrLetStmt *op, Expr v, Scope<> &scope) : op(op), new_value(std::move(v)), binding(scope, op->name) { } }; vector frames; - Body result; + decltype(op->body) result; do { frames.emplace_back(op, mutate(op->value), varying); result = op->body; - } while ((op = result.template as())); + } while ((op = result.template as())); result = mutate(result); @@ -136,7 +136,7 @@ class LiftLoopInvariants : public IRMutator { if (frame.new_value.same_as(frame.op->value) && result.same_as(frame.op->body)) { result = frame.op; } else { - result = T::make(frame.op->name, std::move(frame.new_value), result); + result = LetOrLetStmt::make(frame.op->name, std::move(frame.new_value), result); } } @@ -144,11 +144,11 @@ class LiftLoopInvariants : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const For *op) override { @@ -476,20 +476,20 @@ class GroupLoopInvariants : public IRMutator { return stmt; } - template - Body visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { struct Frame { - const T *op; + const LetOrLetStmt *op; Expr new_value; ScopedBinding binding; - Frame(const T *op, Expr v, int depth, Scope &scope) + Frame(const LetOrLetStmt *op, Expr v, int depth, Scope &scope) : op(op), new_value(std::move(v)), binding(scope, op->name, depth) { } }; std::vector frames; - Body result; + decltype(op->body) result; do { result = op->body; @@ -498,7 +498,7 @@ class GroupLoopInvariants : public IRMutator { d = expr_depth(op->value); } frames.emplace_back(op, mutate(op->value), d, var_depth); - } while ((op = result.template as())); + } while ((op = result.template as())); result = mutate(result); @@ -506,7 +506,7 @@ class GroupLoopInvariants : public IRMutator { if (frame.new_value.same_as(frame.op->value) && result.same_as(frame.op->body)) { result = frame.op; } else { - result = T::make(frame.op->name, frame.new_value, result); + result = LetOrLetStmt::make(frame.op->name, frame.new_value, result); } } @@ -514,11 +514,11 @@ class GroupLoopInvariants : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } }; diff --git a/src/LowerWarpShuffles.cpp b/src/LowerWarpShuffles.cpp index fa3dfecb9a1f..2551fe0bffbb 100644 --- a/src/LowerWarpShuffles.cpp +++ b/src/LowerWarpShuffles.cpp @@ -691,10 +691,10 @@ class HoistWarpShufflesFromSingleIfStmt : public IRMutator { } } - template - ExprOrStmt visit_let(const LetOrLetStmt *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { Expr value = mutate(op->value); - ExprOrStmt body = mutate(op->body); + auto body = mutate(op->body); // If any of the lifted expressions use this, we also need to // lift this. @@ -712,11 +712,11 @@ class HoistWarpShufflesFromSingleIfStmt : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const For *op) override { diff --git a/src/OptimizeShuffles.cpp b/src/OptimizeShuffles.cpp index c986af62474e..0a88d02f0b60 100644 --- a/src/OptimizeShuffles.cpp +++ b/src/OptimizeShuffles.cpp @@ -38,13 +38,13 @@ class OptimizeShuffles : public IRMutator { return IRMutator::visit(op); } - template - NodeType visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // We only care about vector lets. if (op->value.type().is_vector()) { bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds)); } - NodeType node = IRMutator::visit(op); + auto node = IRMutator::visit(op); if (op->value.type().is_vector()) { bounds.pop(op->name); } @@ -53,12 +53,12 @@ class OptimizeShuffles : public IRMutator { Expr visit(const Let *op) override { lets.emplace_back(op->name, op->value); - Expr expr = visit_let(op); + Expr expr = visit_let(op); lets.pop_back(); return expr; } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } std::set allocations_to_pad; @@ -149,4 +149,4 @@ Stmt optimize_shuffles(Stmt s, int lut_alignment) { } } // namespace Internal -} // namespace Halide \ No newline at end of file +} // namespace Halide diff --git a/src/RemoveUndef.cpp b/src/RemoveUndef.cpp index 9103e401fb33..c738b23311aa 100644 --- a/src/RemoveUndef.cpp +++ b/src/RemoveUndef.cpp @@ -241,25 +241,25 @@ class RemoveUndef : public IRMutator { } } - template - Body visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // Visit an entire chain of lets in a single method to conserve stack space. struct Frame { - const T *op; + const LetOrLetStmt *op; Expr new_value; ScopedBinding<> binding; - Frame(const T *op, Expr v, Scope<> &scope) + Frame(const LetOrLetStmt *op, Expr v, Scope<> &scope) : op(op), new_value(std::move(v)), binding(!new_value.defined(), scope, op->name) { } }; vector frames; - Body result; + decltype(op->body) result; do { frames.emplace_back(op, mutate(op->value), dead_vars); result = op->body; - } while ((op = result.template as())); + } while ((op = result.template as())); result = mutate(result); @@ -272,7 +272,7 @@ class RemoveUndef : public IRMutator { if (frame.new_value.same_as(frame.op->value) && result.same_as(frame.op->body)) { result = frame.op; } else { - result = T::make(frame.op->name, std::move(frame.new_value), result); + result = LetOrLetStmt::make(frame.op->name, std::move(frame.new_value), result); } } } @@ -281,11 +281,11 @@ class RemoveUndef : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const AssertStmt *op) override { diff --git a/src/SimplifyCorrelatedDifferences.cpp b/src/SimplifyCorrelatedDifferences.cpp index 7c6a3bac735c..ba9c2d772fff 100644 --- a/src/SimplifyCorrelatedDifferences.cpp +++ b/src/SimplifyCorrelatedDifferences.cpp @@ -78,8 +78,8 @@ class SimplifyCorrelatedDifferences : public IRMutator { }; vector lets; - template - StmtOrExpr visit_let(const LetStmtOrLet *op) { + template + auto visit_let(const LetStmtOrLet *op) -> decltype(op->body) { // Visit an entire chain of lets in a single method to conserve stack space. struct Frame { const LetStmtOrLet *op; @@ -94,7 +94,7 @@ class SimplifyCorrelatedDifferences : public IRMutator { } }; std::vector frames; - StmtOrExpr result; + decltype(op->body) result; // Note that we must add *everything* that depends on the loop // var to the monotonic scope and the list of lets, even @@ -146,11 +146,11 @@ class SimplifyCorrelatedDifferences : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const For *op) override { diff --git a/src/TrimNoOps.cpp b/src/TrimNoOps.cpp index 1d7232a89b11..e516a17b12bd 100644 --- a/src/TrimNoOps.cpp +++ b/src/TrimNoOps.cpp @@ -309,10 +309,10 @@ class SimplifyUsingBounds : public IRMutator { return visit_cmp(op); } - template - StmtOrExpr visit_let(const LetStmtOrLet *op) { + template + auto visit_let(const LetStmtOrLet *op) -> decltype(op->body) { Expr value = mutate(op->value); - StmtOrExpr body; + decltype(op->body) body; if (value.type() == Int(32) && is_pure(value)) { containing_loops.push_back({op->name, {value, value}}); body = mutate(op->body); @@ -324,11 +324,11 @@ class SimplifyUsingBounds : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const For *op) override { diff --git a/src/UniquifyVariableNames.cpp b/src/UniquifyVariableNames.cpp index 2d89e77de787..9dc92c780b3a 100644 --- a/src/UniquifyVariableNames.cpp +++ b/src/UniquifyVariableNames.cpp @@ -130,15 +130,15 @@ class FindFreeVars : public IRVisitor { } } - template - void visit_let(const T *op) { + template + void visit_let(const LetOrLetStmt *op) { vector> frame; decltype(op->body) body; do { op->value.accept(this); frame.emplace_back(scope, op->name); body = op->body; - op = body.template as(); + op = body.template as(); } while (op); body.accept(this); } diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index 7b16083ca5e2..0c7e762ad262 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -1411,8 +1411,8 @@ class FindVectorizableExprsInAtomicNode : public IRMutator { using IRMutator::visit; - template - const T *visit_let(const T *op) { + template + const LetOrLetStmt *visit_let(const LetOrLetStmt *op) { mutate(op->value); ScopedBinding<> bind_if(poison, poisoned_names, op->name); mutate(op->body); @@ -1498,8 +1498,8 @@ class LiftVectorizableExprsOutOfSingleAtomicNode : public IRMutator { using IRMutator::visit; - template - StmtOrExpr visit_let(const LetStmtOrLet *op) { + template + auto visit_let(const LetStmtOrLet *op) -> decltype(op->body) { if (liftable.count(op->value)) { // Lift it under its current name to avoid having to // rewrite the variables in other lifted exprs. @@ -1512,11 +1512,11 @@ class LiftVectorizableExprsOutOfSingleAtomicNode : public IRMutator { } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } public: