Skip to content

Commit

Permalink
Merge branch 'main' into pr/7848
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-johnson committed Sep 25, 2023
2 parents d92053b + 05d5efa commit 3bd60aa
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,8 @@ Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) {
internal_assert(condition.defined() && then_case.defined()) << "IfThenElse of undefined\n";
// else_case may be null.

internal_assert(condition.type().is_scalar()) << "IfThenElse with vector condition\n";

IfThenElse *node = new IfThenElse;
node->condition = std::move(condition);
node->then_case = std::move(then_case);
Expand Down
2 changes: 1 addition & 1 deletion src/PartitionLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ class FindSimplifications : public IRVisitor {
}
condition = remove_likelies(condition);
Simplification s = {condition, std::move(old), std::move(likely_val), std::move(unlikely_val), true};
if (s.condition.type().is_vector()) {
while (s.condition.type().is_vector()) {
s.condition = simplify(s.condition);
if (const Broadcast *b = s.condition.as<Broadcast>()) {
s.condition = b->value;
Expand Down
4 changes: 4 additions & 0 deletions src/Simplify_Stmts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,10 @@ Stmt Simplify::visit(const Store *op) {

const Load *load = value.as<Load>();
const Broadcast *scalar_pred = predicate.as<Broadcast>();
if (scalar_pred && !scalar_pred->value.type().is_scalar()) {
// Nested vectorization
scalar_pred = nullptr;
}

ModulusRemainder align = ModulusRemainder::intersect(op->alignment, base_info.alignment);

Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ tests(GROUPS correctness
fuse_gpu_threads.cpp
fused_where_inner_extent_is_zero.cpp
fuzz_float_stores.cpp
fuzz_schedule.cpp
gameoflife.cpp
gather.cpp
gpu_allocation_cache.cpp
Expand Down
60 changes: 60 additions & 0 deletions test/correctness/fuzz_schedule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "Halide.h"

using namespace Halide;

void check_blur_output(const Buffer<int> &out, const Buffer<int> &correct) {
for (int y = 0; y < out.height(); y++) {
for (int x = 0; x < out.width(); x++) {
if (out(x, y) != correct(x, y)) {
printf("out(%d, %d) = %d instead of %d\n",
x, y, out(x, y), correct(x, y));
exit(1);
}
}
}
}

int main(int argc, char **argv) {
// This test is for schedules that crash the compiler found via fuzzing that
// are hard to otherwise reproduce. We don't need to check the output.

Buffer<int> correct;
{
// An unscheduled instance to act as a reference output
Func input("input");
Func local_sum("local_sum");
Func blurry("blurry");
Var x("x"), y("y");
input(x, y) = 2 * x + 5 * y;
RDom r(-2, 5, -2, 5);
local_sum(x, y) = 0;
local_sum(x, y) += input(x + r.x, y + r.y);
blurry(x, y) = cast<int32_t>(local_sum(x, y) / 25);
correct = blurry.realize({32, 32});
}

// https://github.com/halide/Halide/issues/7851
{
Func input("input");
Func local_sum("local_sum");
Func blurry("blurry");
Var x("x"), y("y");
input(x, y) = 2 * x + 5 * y;
RDom r(-2, 5, -2, 5);
local_sum(x, y) = 0;
local_sum(x, y) += input(x + r.x, y + r.y);
blurry(x, y) = cast<int32_t>(local_sum(x, y) / 25);
Var yo("yo"), yi("yi"), xo("xo"), xi("xi"), yo_x_f("yo_x_f"), yo_x_fo("yo_x_fo"), yo_x_fi("yo_x_fi");
blurry.split(y, yo, yi, 2, TailStrategy::RoundUp).fuse(yo, x, yo_x_f).vectorize(yi).split(yo_x_f, yo_x_fo, yo_x_fi, 2, TailStrategy::Predicate).reorder(yo_x_fo, yo_x_fi, yi);
input.split(y, yo, yi, 2, TailStrategy::PredicateStores).fuse(yo, x, yo_x_f).vectorize(yi).split(yo_x_f, yo_x_fo, yo_x_fi, 2, TailStrategy::Predicate).reorder(yo_x_fo, yo_x_fi, yi);
blurry.store_root();
input.compute_at(blurry, yi);
Pipeline p({blurry});
Buffer<int> buf = p.realize({32, 32});
check_blur_output(buf, correct);
}

printf("Success!\n");

return 0;
}

0 comments on commit 3bd60aa

Please sign in to comment.