Skip to content

Commit

Permalink
Fix read-after-write hazard analysis in storage folding (#7910)
Browse files Browse the repository at this point in the history
Explicitly mark which loops get loop-carry-dependencies inserted by
sliding window to assist storage folding.

Storage folding needs to know about this so it doesn't try to fold in a
way that invalidates these read-after-write dependencies. It currently
tries to prove the absence of hazards with box_contains(box_provided,
box_required), but this is sometimes incorrect because box_provided
could be conservatively large, and the code it analyses might not
actually provide (store to) all the required (loaded from) values.

It's simpler for sliding window to just tell storage folding when it
inserts loop-carry-dependencies, and this is most simply done directly
in the IR itself.

Fixes #7909
  • Loading branch information
abadams authored Oct 24, 2023
1 parent d023065 commit fffb8bd
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 9 deletions.
1 change: 1 addition & 0 deletions src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ const char *const intrinsic_op_names[] = {
"shift_right",
"signed_integer_overflow",
"size_of_halide_buffer_t",
"sliding_window_marker",
"sorted_avg",
"strict_float",
"stringify",
Expand Down
9 changes: 9 additions & 0 deletions src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,15 @@ struct Call : public ExprNode<Call> {
signed_integer_overflow,
size_of_halide_buffer_t,

// Takes a realization name and a loop variable. Declares that values of
// the realization that were stored on earlier loop iterations of the
// given loop are potentially loaded in this loop iteration somewhere
// after this point. Must occur inside a Realize node and For node of
// the given names but outside any corresponding ProducerConsumer
// nodes. Communicates to storage folding that sliding window took
// place.
sliding_window_marker,

// Compute (arg[0] + arg[1]) / 2, assuming arg[0] < arg[1].
sorted_avg,
strict_float,
Expand Down
13 changes: 12 additions & 1 deletion src/SlidingWindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,9 @@ class SlidingWindow : public IRMutator {
}
}

SlidingWindowOnFunctionAndLoop slider(func, name, prev_loop_min, slid_dimensions[func.name()]);
set<int> &slid_dims = slid_dimensions[func.name()];
size_t old_slid_dims_size = slid_dims.size();
SlidingWindowOnFunctionAndLoop slider(func, name, prev_loop_min, slid_dims);
body = slider.mutate(body);

if (func.schedule().memory_type() == MemoryType::Register &&
Expand Down Expand Up @@ -856,6 +858,15 @@ class SlidingWindow : public IRMutator {
new_lets.emplace_front(name + ".loop_min.orig", loop_min);
new_lets.emplace_front(name + ".loop_extent", (loop_max - loop_min) + 1);
}

if (slid_dims.size() > old_slid_dims_size) {
// Let storage folding know there's now a read-after-write hazard here
Expr marker = Call::make(Int(32),
Call::sliding_window_marker,
{func.name(), Variable::make(Int(32), op->name)},
Call::Intrinsic);
body = Block::make(Evaluate::make(marker), body);
}
}

body = mutate(body);
Expand Down
47 changes: 40 additions & 7 deletions src/StorageFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,28 @@ class AttemptStorageFoldingOfFunction : public IRMutator {
}
}

bool found_sliding_marker = false;
Expr visit(const Call *op) override {
if (op->is_intrinsic(Call::sliding_window_marker)) {
internal_assert(op->args.size() == 2);
const StringImm *name = op->args[0].as<StringImm>();
internal_assert(name);
if (name->value == func.name()) {
found_sliding_marker = true;
}
}
return op;
}

Stmt visit(const Block *op) override {
Stmt first = mutate(op->first);
if (found_sliding_marker) {
return Block::make(first, op->rest);
} else {
return Block::make(first, mutate(op->rest));
}
}

Stmt visit(const For *op) override {
if (op->for_type != ForType::Serial && op->for_type != ForType::Unrolled) {
// We can't proceed into a parallel for loop.
Expand Down Expand Up @@ -878,12 +900,10 @@ class AttemptStorageFoldingOfFunction : public IRMutator {
}
}

// If there's no communication of values from one loop
// iteration to the next (which may happen due to sliding),
// then we're safe to fold an inner loop.
if (box_contains(provided, required)) {
body = mutate(body);
}
// Attempt to fold an inner loop. This will bail out if it encounters a
// ProducerConsumer node for the func, or if it hits a sliding window
// marker.
body = mutate(body);

if (body.same_as(op->body)) {
stmt = op;
Expand Down Expand Up @@ -1010,10 +1030,23 @@ class StorageFolding : public IRMutator {
}
};

class RemoveSlidingWindowMarkers : public IRMutator {
using IRMutator::visit;
Expr visit(const Call *op) override {
if (op->is_intrinsic(Call::sliding_window_marker)) {
return make_zero(op->type);
} else {
return IRMutator::visit(op);
}
}
};

} // namespace

Stmt storage_folding(const Stmt &s, const std::map<std::string, Function> &env) {
return StorageFolding(env).mutate(s);
Stmt stmt = StorageFolding(env).mutate(s);
stmt = RemoveSlidingWindowMarkers().mutate(stmt);
return stmt;
}

} // namespace Internal
Expand Down
23 changes: 22 additions & 1 deletion test/correctness/fuzz_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,28 @@ int main(int argc, char **argv) {
check_blur_output(buf, correct);
}

printf("Success!\n");
// https://github.com/halide/Halide/issues/7909
{
Func input("input");
Func local_sum("local_sum");
Func blurry("blurry");
Var x("x"), y("y");
input(x, y) = 2 * x + 5 * y;
RDom r(-2, 5, -2, 5);
local_sum(x, y) = 0;
local_sum(x, y) += input(x + r.x, y + r.y);
blurry(x, y) = cast<int32_t>(local_sum(x, y) / 25);
Var yo, yi;
blurry.split(y, yo, yi, 1, TailStrategy::Auto);
local_sum.compute_at(blurry, yo);
local_sum.store_root();
input.compute_at(local_sum, x);
input.store_root();
Pipeline p({blurry});
Buffer<int> buf = p.realize({32, 32});
check_blur_output(buf, correct);
}

printf("Success!\n");
return 0;
}

0 comments on commit fffb8bd

Please sign in to comment.