Skip to content

Commit

Permalink
Backport reverse_view to clean up some code (#8486)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexreinking authored Nov 23, 2024
1 parent b4b9911 commit 922e469
Show file tree
Hide file tree
Showing 41 changed files with 431 additions and 435 deletions.
25 changes: 12 additions & 13 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,20 +543,20 @@ class InitializeSemaphores : public IRMutator {
body = LetStmt::make(op->name, std::move(sema_allocate), std::move(body));

// Re-wrap any other lets
for (auto it = lets.rbegin(); it != lets.rend(); it++) {
body = LetStmt::make(it->first, it->second, std::move(body));
for (const auto &[var, value] : reverse_view(lets)) {
body = LetStmt::make(var, value, std::move(body));
}
}
} else {
body = mutate(frames.back()->body);
}

for (auto it = frames.rbegin(); it != frames.rend(); it++) {
Expr value = mutate((*it)->value);
if (value.same_as((*it)->value) && body.same_as((*it)->body)) {
body = *it;
for (const auto *frame : reverse_view(frames)) {
Expr value = mutate(frame->value);
if (value.same_as(frame->value) && body.same_as(frame->body)) {
body = frame;
} else {
body = LetStmt::make((*it)->name, std::move(value), std::move(body));
body = LetStmt::make(frame->name, std::move(value), std::move(body));
}
}
return body;
Expand Down Expand Up @@ -654,8 +654,8 @@ class TightenProducerConsumerNodes : public IRMutator {
body = make_producer_consumer(name, is_producer, body, scope, uses_vars);
}

for (auto it = containing_lets.rbegin(); it != containing_lets.rend(); it++) {
body = LetStmt::make((*it)->name, (*it)->value, body);
for (const auto *container : reverse_view(containing_lets)) {
body = LetStmt::make(container->name, container->value, body);
}

return body;
Expand Down Expand Up @@ -846,8 +846,7 @@ class ExpandAcquireNodes : public IRMutator {
result = mutate(result);

vector<pair<Expr, Expr>> semaphores;
for (auto it = stmts.rbegin(); it != stmts.rend(); it++) {
Stmt s = *it;
for (Stmt s : reverse_view(stmts)) {
while (const Acquire *a = s.as<Acquire>()) {
semaphores.emplace_back(a->semaphore, a->count);
s = a->body;
Expand Down Expand Up @@ -916,8 +915,8 @@ class ExpandAcquireNodes : public IRMutator {
}

// Rewrap the rest of the lets
for (auto it = frames.rbegin(); it != frames.rend(); it++) {
s = LetStmt::make((*it)->name, (*it)->value, s);
for (const auto *let : reverse_view(frames)) {
s = LetStmt::make(let->name, let->value, s);
}

return s;
Expand Down
4 changes: 2 additions & 2 deletions src/BoundConstantExtentLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ class BoundLoops : public IRMutator {
if (e == nullptr) {
// We're about to hard fail. Get really aggressive
// with the simplifier.
for (auto it = lets.rbegin(); it != lets.rend(); it++) {
extent = Let::make(it->first, it->second, extent);
for (const auto &[var, value] : reverse_view(lets)) {
extent = Let::make(var, value, extent);
}
extent = remove_likelies(extent);
extent = substitute_in_all_lets(extent);
Expand Down
4 changes: 2 additions & 2 deletions src/BoundSmallAllocations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class BoundSmallAllocations : public IRMutator {

result = mutate(result);

for (auto it = frames.rbegin(); it != frames.rend(); it++) {
result = T::make(it->op->name, it->op->value, result);
for (const auto &frame : reverse_view(frames)) {
result = T::make(frame.op->name, frame.op->value, result);
}

return result;
Expand Down
56 changes: 28 additions & 28 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2104,14 +2104,14 @@ class SolveIfThenElse : public IRMutator {
Stmt s = mutate(body);

if (s.same_as(body)) {
for (auto it = frames.rbegin(); it != frames.rend(); it++) {
pop_var((*it)->name);
for (const auto *frame : reverse_view(frames)) {
pop_var(frame->name);
}
return orig;
} else {
for (auto it = frames.rbegin(); it != frames.rend(); it++) {
pop_var((*it)->name);
s = LetStmt::make((*it)->name, (*it)->value, s);
for (const auto *frame : reverse_view(frames)) {
pop_var(frame->name);
s = LetStmt::make(frame->name, frame->value, s);
}
return s;
}
Expand Down Expand Up @@ -2590,64 +2590,64 @@ class BoxesTouched : public IRGraphVisitor {

result.accept(this);

for (auto it = frames.rbegin(); it != frames.rend(); it++) {
for (const auto &frame : reverse_view(frames)) {
// Pop the value bounds
scope.pop(it->op->name);
scope.pop(frame.op->name);

if (it->op->value.type() == type_of<struct halide_buffer_t *>()) {
buffer_lets.erase(it->op->name);
if (frame.op->value.type() == type_of<struct halide_buffer_t *>()) {
buffer_lets.erase(frame.op->name);
}

if (!it->min_name.empty()) {
if (!frame.min_name.empty()) {
// We made up new names for the bounds of the
// value, and need to rewrap any boxes we're
// returning with appropriate lets.
for (pair<const string, Box> &i : boxes) {
Box &box = i.second;
for (size_t i = 0; i < box.size(); i++) {
if (box[i].has_lower_bound()) {
if (expr_uses_var(box[i].min, it->max_name)) {
box[i].min = Let::make(it->max_name, it->value_bounds.max, box[i].min);
if (expr_uses_var(box[i].min, frame.max_name)) {
box[i].min = Let::make(frame.max_name, frame.value_bounds.max, box[i].min);
}
if (expr_uses_var(box[i].min, it->min_name)) {
box[i].min = Let::make(it->min_name, it->value_bounds.min, box[i].min);
if (expr_uses_var(box[i].min, frame.min_name)) {
box[i].min = Let::make(frame.min_name, frame.value_bounds.min, box[i].min);
}
}
if (box[i].has_upper_bound()) {
if (expr_uses_var(box[i].max, it->max_name)) {
box[i].max = Let::make(it->max_name, it->value_bounds.max, box[i].max);
if (expr_uses_var(box[i].max, frame.max_name)) {
box[i].max = Let::make(frame.max_name, frame.value_bounds.max, box[i].max);
}
if (expr_uses_var(box[i].max, it->min_name)) {
box[i].max = Let::make(it->min_name, it->value_bounds.min, box[i].max);
if (expr_uses_var(box[i].max, frame.min_name)) {
box[i].max = Let::make(frame.min_name, frame.value_bounds.min, box[i].max);
}
}
}
}
}

if (is_let_stmt::value) {
let_stmts.pop(it->op->name);
let_stmts.pop(frame.op->name);

// If this let stmt shadowed an outer one, we need
// to re-insert the children from the previous let
// stmt into the map.
if (!it->old_let_vars.empty()) {
internal_assert(it->vi.instance > 0);
VarInstance old_vi = VarInstance(it->vi.var, it->vi.instance - 1);
for (const auto &v : it->old_let_vars) {
if (!frame.old_let_vars.empty()) {
internal_assert(frame.vi.instance > 0);
VarInstance old_vi = VarInstance(frame.vi.var, frame.vi.instance - 1);
for (const auto &v : frame.old_let_vars) {
internal_assert(vars_renaming.count(v));
children[get_var_instance(v)].insert(old_vi);
}
}

// Remove the children from the current let stmt.
for (const auto &v : it->collect.vars) {
for (const auto &v : frame.collect.vars) {
internal_assert(vars_renaming.count(v));
children[get_var_instance(v)].erase(it->vi);
children[get_var_instance(v)].erase(frame.vi);
}
}

pop_var(it->op->name);
pop_var(frame.op->name);
}
}

Expand Down Expand Up @@ -3151,8 +3151,8 @@ map<string, Box> boxes_touched(const Expr &e, Stmt s, bool consider_calls, bool
return s;
} else {
// Rewrap the lets around the mutated body
for (auto it = frames.rbegin(); it != frames.rend(); it++) {
s = LetStmt::make((*it)->name, (*it)->value, s);
for (const auto *frame : reverse_view(frames)) {
s = LetStmt::make(frame->name, frame->value, s);
}
return s;
}
Expand Down
16 changes: 7 additions & 9 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,9 @@ class BoundsInference : public IRMutator {
}
}

const vector<Specialization> &specializations = def.specializations();
for (size_t i = specializations.size(); i > 0; i--) {
Expr s_cond = specializations[i - 1].condition;
const Definition &s_def = specializations[i - 1].definition;
for (const auto &s : reverse_view(def.specializations())) {
const Expr s_cond = s.condition;
const Definition &s_def = s.definition;

// Else case (i.e. specialization condition is false)
for (auto &vec : result) {
Expand Down Expand Up @@ -1309,12 +1308,11 @@ class BoundsInference : public IRMutator {
old_inner_productions.end());

// Rewrap the let/if statements
for (size_t i = wrappers.size(); i > 0; i--) {
const auto &p = wrappers[i - 1];
if (p.first.empty()) {
body = IfThenElse::make(p.second, body);
for (const auto &[var, value] : reverse_view(wrappers)) {
if (var.empty()) {
body = IfThenElse::make(value, body);
} else {
body = LetStmt::make(p.first, p.second, body);
body = LetStmt::make(var, value, body);
}
}

Expand Down
16 changes: 8 additions & 8 deletions src/CPlusPlusMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,14 @@ MangledNamePart mangle_inner_name(const Type &type, const Target &target, Previo
result.full_name = quals + code + type.handle_type->inner_name.name + "@";
result.with_substitutions = quals + code + prev_decls.check_and_enter_name(type.handle_type->inner_name.name);

for (size_t i = type.handle_type->enclosing_types.size(); i > 0; i--) {
result.full_name += type.handle_type->enclosing_types[i - 1].name + "@";
result.with_substitutions += prev_decls.check_and_enter_name(type.handle_type->enclosing_types[i - 1].name);
for (const auto &enclosing_type : reverse_view(type.handle_type->enclosing_types)) {
result.full_name += enclosing_type.name + "@";
result.with_substitutions += prev_decls.check_and_enter_name(enclosing_type.name);
}

for (size_t i = type.handle_type->namespaces.size(); i > 0; i--) {
result.full_name += type.handle_type->namespaces[i - 1] + "@";
result.with_substitutions += prev_decls.check_and_enter_name(type.handle_type->namespaces[i - 1]);
for (const auto &ns : reverse_view(type.handle_type->namespaces)) {
result.full_name += ns + "@";
result.with_substitutions += prev_decls.check_and_enter_name(ns);
}

result.full_name += "@";
Expand Down Expand Up @@ -288,8 +288,8 @@ std::string cplusplus_function_mangled_name(const std::string &name, const std::
PreviousDeclarations prev_decls;
result += prev_decls.check_and_enter_name(name);

for (size_t i = namespaces.size(); i > 0; i--) {
result += prev_decls.check_and_enter_name(namespaces[i - 1]);
for (const auto &ns : reverse_view(namespaces)) {
result += prev_decls.check_and_enter_name(ns);
}
result += "@";

Expand Down
10 changes: 4 additions & 6 deletions src/CSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ class CSEEveryExprInStmt : public IRMutator {
internal_assert(bundle && bundle->args.size() == 2);
Stmt s = Store::make(op->name, bundle->args[0], bundle->args[1],
op->param, mutate(op->predicate), op->alignment);
for (auto it = lets.rbegin(); it != lets.rend(); it++) {
s = LetStmt::make(it->first, it->second, s);
for (const auto &[var, value] : reverse_view(lets)) {
s = LetStmt::make(var, value, s);
}
return s;
}
Expand Down Expand Up @@ -336,13 +336,11 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) {
debug(4) << "With variables " << e << "\n";

// Wrap the final expr in the lets.
for (size_t i = lets.size(); i > 0; i--) {
Expr value = lets[i - 1].second;
for (const auto &[var, value] : reverse_view(lets)) {
// Drop this variable as an acceptable replacement for this expr.
replacer.erase(value);
// Use containing lets in the value.
value = replacer.mutate(lets[i - 1].second);
e = Let::make(lets[i - 1].first, value, e);
e = Let::make(var, replacer.mutate(value), e);
}

debug(4) << "With lets: " << e << "\n";
Expand Down
10 changes: 5 additions & 5 deletions src/CanonicalizeGPUVars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,13 @@ class CanonicalizeGPUVars : public IRMutator {

result = mutate(result);

for (auto it = lets.rbegin(); it != lets.rend(); it++) {
std::string name = canonicalize_let(it->first);
if (name != it->first) {
for (const auto &[var, value] : reverse_view(lets)) {
std::string name = canonicalize_let(var);
if (name != var) {
Expr new_var = Variable::make(Int(32), name);
result = substitute(it->first, new_var, result);
result = substitute(var, new_var, result);
}
result = LetStmt::make(name, it->second, result);
result = LetStmt::make(name, value, result);
}

return result;
Expand Down
4 changes: 2 additions & 2 deletions src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1123,8 +1123,8 @@ void CodeGen_C::compile(const LoweredFunc &f, const MetadataNameMap &metadata_na

if (!namespaces.empty()) {
stream << "\n";
for (size_t i = namespaces.size(); i > 0; i--) {
stream << "} // namespace " << namespaces[i - 1] << "\n";
for (const auto &ns : reverse_view(namespaces)) {
stream << "} // namespace " << ns << "\n";
}
stream << "\n";
}
Expand Down
3 changes: 1 addition & 2 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3245,8 +3245,7 @@ void CodeGen_LLVM::visit(const Call *op) {
// Build the not-already-inited case
builder->SetInsertPoint(global_not_inited_bb);
llvm::Value *selected_value = nullptr;
for (int i = sub_fns.size() - 1; i >= 0; i--) {
const auto &sub_fn = sub_fns[i];
for (const auto &sub_fn : reverse_view(sub_fns)) {
if (!selected_value) {
selected_value = sub_fn.fn_ptr;
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/CodeGen_PyTorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ void CodeGen_PyTorch::compile(const LoweredFunc &f, bool is_cuda) {

if (!namespaces.empty()) {
stream << "\n";
for (size_t i = namespaces.size(); i > 0; i--) {
stream << "} // namespace " << namespaces[i - 1] << "\n";
for (const auto &ns : reverse_view(namespaces)) {
stream << "} // namespace " << ns << "\n";
}
stream << "\n";
}
Expand Down
16 changes: 8 additions & 8 deletions src/Deinterleave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,21 +481,21 @@ class Interleaver : public IRMutator {

result = mutate(result);

for (auto it = frames.rbegin(); it != frames.rend(); it++) {
Expr value = std::move(it->new_value);
for (const auto &frame : reverse_view(frames)) {
Expr value = std::move(frame.new_value);

result = T::make(it->op->name, value, result);
result = T::make(frame.op->name, value, result);

// For vector lets, we may additionally need a let defining the even and odd lanes only
if (value.type().is_vector()) {
if (value.type().lanes() % 2 == 0) {
result = T::make(it->op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result);
result = T::make(it->op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result);
result = T::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result);
result = T::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result);
}
if (value.type().lanes() % 3 == 0) {
result = T::make(it->op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result);
result = T::make(it->op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result);
result = T::make(it->op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result);
result = T::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result);
result = T::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result);
result = T::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result);
}
}
}
Expand Down
Loading

0 comments on commit 922e469

Please sign in to comment.