diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 033556e350a1..479d71ce6fac 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -266,10 +266,6 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { string prefix = func.name() + ".s" + std::to_string(func.updates().size()) + "."; const std::vector func_args = func.args(); for (int i = 0; i < func.dimensions(); i++) { - if (slid_dimensions.count(i)) { - debug(3) << "Already slid over dimension " << i << ", so skipping it.\n"; - continue; - } // Look up the region required of this function's last stage string var = prefix + func_args[i]; internal_assert(scope.contains(var + ".min") && scope.contains(var + ".max")); @@ -304,6 +300,12 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { } } + if (!dim.empty() && slid_dimensions.count(dim_idx)) { + debug(1) << "Already slid over dimension " << dim_idx << ", so skipping it.\n"; + dim = ""; + min_required = Expr(); + max_required = Expr(); + } if (!min_required.defined()) { debug(3) << "Could not perform sliding window optimization of " << func.name() << " over " << loop_var << " because multiple " diff --git a/test/correctness/fuzz_schedule.cpp b/test/correctness/fuzz_schedule.cpp index 60a780a89d6e..d5a2a664fec5 100644 --- a/test/correctness/fuzz_schedule.cpp +++ b/test/correctness/fuzz_schedule.cpp @@ -74,6 +74,28 @@ int main(int argc, char **argv) { check_blur_output(buf, correct); } + // https://github.com/halide/Halide/issues/7872 + { + Func input("input"); + Func local_sum("local_sum"); + Func blurry("blurry"); + Var x("x"), y("y"); + RVar yryf; + input(x, y) = 2 * x + 5 * y; + RDom r(-2, 5, -2, 5, "rdom_r"); + local_sum(x, y) = 0; + local_sum(x, y) += input(x + r.x, y + r.y); + blurry(x, y) = cast(local_sum(x, y) / 25); + Var xo, xi; + blurry.split(x, xo, xi, 2, TailStrategy::GuardWithIf); + local_sum.store_at(blurry, y).compute_at(blurry, xi); + // local_sum.store_root(); + blurry.bound(y, 0, 1); + Pipeline p({blurry}); + Buffer buf = p.realize({4, 1}); + check_blur_output(buf, correct); + } + printf("Success!\n"); return 0;