Skip to content

Commit

Permalink
Stop interleaver from expanding the scope of letstmts (#7908)
Browse files Browse the repository at this point in the history
In the following code:

let a = b in
  X
let a = c in
  Y

If Stmt X successfully had stores interleaved, it was re-nesting it like
so:

let a = b in
  X
  let a = c in
    Y

This introduces a shadowed variable 'a', which is illegal at this stage
of lowering.

Fixes #7906

Also some drive-by fixes to earlier tests that had debugging code left
in.
  • Loading branch information
abadams authored Oct 20, 2023
1 parent eb66c06 commit bd1d4df
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
10 changes: 5 additions & 5 deletions src/Deinterleave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -759,17 +759,17 @@ class Interleaver : public IRMutator {
Expr predicate = Shuffle::make_interleave(predicates);
Stmt new_store = Store::make(store->name, value, index, store->param, predicate, ModulusRemainder());

// Continue recursively into the stuff that
// collect_strided_stores didn't collect.
Stmt stmt = Block::make(new_store, mutate(rest));

// Rewrap the let statements we pulled off.
while (!let_stmts.empty()) {
const LetStmt *let = let_stmts.back().as<LetStmt>();
stmt = LetStmt::make(let->name, let->value, stmt);
new_store = LetStmt::make(let->name, let->value, new_store);
let_stmts.pop_back();
}

// Continue recursively into the stuff that
// collect_strided_stores didn't collect.
Stmt stmt = Block::make(new_store, mutate(rest));

// Success!
return stmt;
}
Expand Down
25 changes: 22 additions & 3 deletions test/correctness/fuzz_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,8 @@ int main(int argc, char **argv) {
Var xo, xi;
blurry.split(x, xo, xi, 2, TailStrategy::GuardWithIf);
local_sum.store_at(blurry, y).compute_at(blurry, xi);
// local_sum.store_root();
blurry.bound(y, 0, 1);
Pipeline p({blurry});
Buffer<int> buf = p.realize({4, 1});
Buffer<int> buf = p.realize({32, 32});
check_blur_output(buf, correct);
}

Expand Down Expand Up @@ -140,6 +138,27 @@ int main(int argc, char **argv) {
local_sum.update(0).unscheduled();
Pipeline p({blurry});
Buffer<int> buf = p.realize({32, 32});
check_blur_output(buf, correct);
}

// https://github.com/halide/Halide/issues/7906
{
Func input("input");
Func local_sum("local_sum");
Func blurry("blurry");
Var x("x"), y("y");
input(x, y) = 2 * x + 5 * y;
RDom r(-2, 5, -2, 5);
local_sum(x, y) = 0;
local_sum(x, y) += input(x + r.x, y + r.y);
blurry(x, y) = cast<int32_t>(local_sum(x, y) / 25);
Var yo, yi, x_yo_f;
input.vectorize(y).split(y, yo, yi, 2, TailStrategy::ShiftInwards).unroll(x).fuse(x, yo, x_yo_f);
blurry.compute_root();
input.compute_at(blurry, x);
Pipeline p({blurry});
Buffer<int> buf = p.realize({32, 32});
check_blur_output(buf, correct);
}

printf("Success!\n");
Expand Down

0 comments on commit bd1d4df

Please sign in to comment.