From c2f87bce507a2c364858483ae47a10cb708efd96 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Mon, 16 Dec 2024 12:30:17 -0800 Subject: [PATCH 01/16] Remove unused function in HexagonOptimize (#8511) --- src/HexagonOptimize.cpp | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index 92f3965e1600..d75231813215 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -141,24 +141,6 @@ string type_suffix(const vector &ops, bool signed_variants) { namespace { -// Helper to handle various forms of multiplication. -Expr as_mul(const Expr &a) { - if (a.as()) { - return a; - } else if (const Call *wm = Call::as_intrinsic(a, {Call::widening_mul})) { - return simplify(Mul::make(cast(wm->type, wm->args[0]), cast(wm->type, wm->args[1]))); - } else if (const Call *s = Call::as_intrinsic(a, {Call::shift_left, Call::widening_shift_left})) { - auto log2_b = as_const_uint(s->args[1]); - if (log2_b) { - Expr b = make_one(s->type) << cast(UInt(s->type.bits()), (int)*log2_b); - return simplify(Mul::make(cast(s->type, s->args[0]), b)); - } - } else if (const Call *wm = Call::as_intrinsic(a, {Call::widen_right_mul})) { - return simplify(Mul::make(wm->args[0], cast(wm->type, wm->args[1]))); - } - return Expr(); -} - // Helpers to generate horizontally reducing multiply operations. Expr halide_hexagon_add_2mpy(Type result_type, const string &suffix, Expr v0, Expr v1, Expr c0, Expr c1) { Expr call = Call::make(result_type, "halide.hexagon.add_2mpy" + suffix, From c6458ff0ff50a0f359250f04e81c48ba02deb82b Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 16 Dec 2024 12:30:41 -0800 Subject: [PATCH 02/16] Limit depth more strictly in CSE fuzz test (#8512) --- test/fuzz/cse.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/test/fuzz/cse.cpp b/test/fuzz/cse.cpp index 44a5e2c06a08..1874bef94968 100644 --- a/test/fuzz/cse.cpp +++ b/test/fuzz/cse.cpp @@ -8,19 +8,24 @@ using namespace Halide; using namespace Halide::ConciseCasts; using namespace Halide::Internal; +using std::pair; using std::vector; // Note that this deliberately uses int16 values everywhere -- // *not* int32 -- because we want to test CSE, not the simplifier's // overflow behavior, and using int32 can end up with results // containing signed_integer_overflow(), which is not helpful here. -Expr random_expr(FuzzedDataProvider &fdp, int depth, vector &exprs) { +Expr random_expr(FuzzedDataProvider &fdp, int depth, vector> &exprs) { if (depth <= 0) { return i16(fdp.ConsumeIntegralInRange(-5, 4)); } if (!exprs.empty() && fdp.ConsumeBool()) { - // Reuse an existing expression - return pick_value_in_vector(fdp, exprs); + // Reuse an existing expression that was generated under conditions at + // least as strict as our current depth limit. + auto p = pick_value_in_vector(fdp, exprs); + if (p.second <= depth) { + return p.first; + } } std::function build_next_expr[] = { [&]() { @@ -67,13 +72,13 @@ Expr random_expr(FuzzedDataProvider &fdp, int depth, vector &exprs) { }, }; Expr next = fdp.PickValueInArray(build_next_expr)(); - exprs.push_back(next); + exprs.emplace_back(next, depth); return next; } extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { FuzzedDataProvider fdp(data, size); - vector exprs; + vector> exprs; Expr orig = random_expr(fdp, 5, exprs); Expr csed = common_subexpression_elimination(orig); From 153e35d929b92e9d08ec9a5aadf784d36209acc5 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Tue, 17 Dec 2024 11:40:44 -0500 Subject: [PATCH 03/16] Fix workflow for next release (#8514) --- .github/workflows/pip.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pip.yml b/.github/workflows/pip.yml index c83cff317555..a0eaf790c76c 100644 --- a/.github/workflows/pip.yml +++ b/.github/workflows/pip.yml @@ -6,9 +6,9 @@ name: Build PyPI package on: push: - branches: [ main, build/pip-packaging ] + branches: [ main ] release: - types: [ created ] + types: [ published ] env: # TODO: detect this from repo somehow: https://github.com/halide/Halide/issues/8406 @@ -253,7 +253,7 @@ jobs: - uses: pypa/gh-action-pypi-publish@release/v1 if: github.event_name == 'push' && github.ref_name == 'main' with: - repository_url: https://test.pypi.org/legacy/ + repository-url: https://test.pypi.org/legacy/ - uses: pypa/gh-action-pypi-publish@release/v1 if: github.event_name == 'release' && github.event.action == 'published' From d3f19bda0dd92bf58acc69c4df6f44d4511ddd03 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Wed, 18 Dec 2024 09:53:58 -0800 Subject: [PATCH 04/16] Fix two non-idiomatic uses of node_type (#8520) --- src/Derivative.cpp | 6 ++---- src/Serialization.cpp | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/Derivative.cpp b/src/Derivative.cpp index 2520d27e290f..e64fb4ada94b 100644 --- a/src/Derivative.cpp +++ b/src/Derivative.cpp @@ -337,8 +337,7 @@ void ReverseAccumulationVisitor::propagate_adjoints( let_var_mapping.clear(); let_variables.clear(); for (const auto &expr : expr_list) { - if (expr.get()->node_type == IRNodeType::Let) { - const Let *op = expr.as(); + if (const Let *op = expr.as()) { // Assume Let variables are unique internal_assert(let_var_mapping.find(op->name) == let_var_mapping.end()); let_var_mapping[op->name] = op->value; @@ -660,8 +659,7 @@ void ReverseAccumulationVisitor::propagate_adjoints( let_var_mapping.clear(); let_variables.clear(); for (const auto &expr : expr_list) { - if (expr.get()->node_type == IRNodeType::Let) { - const Let *op = expr.as(); + if (const Let *op = expr.as()) { // Assume Let variables are unique internal_assert(let_var_mapping.find(op->name) == let_var_mapping.end()); let_var_mapping[op->name] = op->value; diff --git a/src/Serialization.cpp b/src/Serialization.cpp index 27bea2f5cfa3..15722d878974 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -410,7 +410,7 @@ std::pair> Serializer::serialize_stmt(FlatBufferBu if (!stmt.defined()) { return std::make_pair(Serialize::Stmt::UndefinedStmt, Serialize::CreateUndefinedStmt(builder).Union()); } - switch (stmt->node_type) { + switch (stmt.node_type()) { case IRNodeType::LetStmt: { const auto *const let_stmt = stmt.as(); const auto name_serialized = serialize_string(builder, let_stmt->name); @@ -681,7 +681,7 @@ std::pair> Serializer::serialize_expr(FlatBufferBu if (!expr.defined()) { return std::make_pair(Serialize::Expr::UndefinedExpr, Serialize::CreateUndefinedExpr(builder).Union()); } - switch (expr->node_type) { + switch (expr.node_type()) { case IRNodeType::IntImm: { const auto *const int_imm = expr.as(); const auto type_serialized = serialize_type(builder, int_imm->type); From ac2cd23951aff9ac3b765e51938f1e576f1f0ee9 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Thu, 19 Dec 2024 12:53:21 -0500 Subject: [PATCH 05/16] Fix Debian packaging (#8524) * Handle DESTDIR in HalidePackageConfigHelpers.cmake Fixes #8521 * Fix Halide_[SO]VERSION_OVERRIDE when value is 0 Fixes #8522 * Upgrade LLVM to 19.1.6 --- .github/workflows/pip.yml | 2 +- cmake/HalidePackageConfigHelpers.cmake | 4 ++-- src/CMakeLists.txt | 9 +++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/pip.yml b/.github/workflows/pip.yml index a0eaf790c76c..3a49a9f94528 100644 --- a/.github/workflows/pip.yml +++ b/.github/workflows/pip.yml @@ -12,7 +12,7 @@ on: env: # TODO: detect this from repo somehow: https://github.com/halide/Halide/issues/8406 - LLVM_VERSION: 19.1.4 + LLVM_VERSION: 19.1.6 FLATBUFFERS_VERSION: 23.5.26 WABT_VERSION: 1.0.36 diff --git a/cmake/HalidePackageConfigHelpers.cmake b/cmake/HalidePackageConfigHelpers.cmake index 4efe5ebf8718..efdacba24a1f 100644 --- a/cmake/HalidePackageConfigHelpers.cmake +++ b/cmake/HalidePackageConfigHelpers.cmake @@ -80,7 +80,7 @@ function(_Halide_install_pkgdeps) set(depFile "${CMAKE_CURRENT_BINARY_DIR}/${ARG_FILE_NAME}") _Halide_install_code( - "file(READ \"\${CMAKE_INSTALL_PREFIX}/${ARG_DESTINATION}/${ARG_EXPORT_FILE}\" target_cmake)" + "file(READ \"\$ENV{DESTDIR}\${CMAKE_INSTALL_PREFIX}/${ARG_DESTINATION}/${ARG_EXPORT_FILE}\" target_cmake)" "file(WRITE \"${depFile}.in\" \"\")" ) @@ -104,4 +104,4 @@ function(_Halide_install_pkgdeps) DESTINATION "${ARG_DESTINATION}" COMPONENT "${ARG_COMPONENT}" ) -endfunction() \ No newline at end of file +endfunction() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f5eb3e64f97f..86c298246d97 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -18,8 +18,10 @@ set(Halide_VERSION_OVERRIDE "${Halide_VERSION}" CACHE STRING "VERSION to set for custom Halide packaging") mark_as_advanced(Halide_VERSION_OVERRIDE) -if (Halide_VERSION_OVERRIDE) - # Empty is considered a value distinct from not-defined +if (NOT Halide_VERSION_OVERRIDE STREQUAL "") + # CMake treats an empty VERSION property differently from leaving it unset. + # We also can't check the boolean-ness of Halide_VERSION_OVERRIDE because + # VERSION 0 is valid. See: https://github.com/halide/Halide/issues/8522 set_target_properties(Halide PROPERTIES VERSION "${Halide_VERSION_OVERRIDE}") endif () @@ -27,8 +29,7 @@ set(Halide_SOVERSION_OVERRIDE "${Halide_VERSION_MAJOR}" CACHE STRING "SOVERSION to set for custom Halide packaging") mark_as_advanced(Halide_SOVERSION_OVERRIDE) -if (Halide_SOVERSION_OVERRIDE) - # Empty is considered a value distinct from not-defined +if (NOT Halide_SOVERSION_OVERRIDE STREQUAL "") set_target_properties(Halide PROPERTIES SOVERSION "${Halide_SOVERSION_OVERRIDE}") endif () From bc9dfbf99de9d52578d28967a31aad71104e068a Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 20 Dec 2024 09:42:51 -0800 Subject: [PATCH 06/16] Handle some misc TODOs (#8528) CodeGen_LLVM.cpp: opaque pointers are now standard, and that flag no longer works. Var.h: We convert strings to Vars in many places internally, and some of those Vars originated from implicit Vars, so it's not feasible to require than the version that takes an explicit string isn't allowed to be passed things of the form "_[0-9]*". You can use the explicit constructor to make collisions with other Vars, and yes this includes the implicit vars. --- src/CodeGen_LLVM.cpp | 3 --- src/Deinterleave.cpp | 11 +++++++++-- src/IRMatch.h | 9 ++------- src/IROperator.cpp | 32 ++++++++++++-------------------- src/Var.h | 7 +++---- 5 files changed, 26 insertions(+), 36 deletions(-) diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 31dddc3551af..a30b30ab3276 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -234,9 +234,6 @@ void CodeGen_LLVM::initialize_llvm() { for (const std::string &s : arg_vec) { c_arg_vec.push_back(s.c_str()); } - // TODO: Remove after opaque pointers become the default in LLVM. - // This is here to document how to turn on opaque pointers, for testing, in LLVM 15 - // c_arg_vec.push_back("-opaque-pointers"); cl::ParseCommandLineOptions((int)(c_arg_vec.size()), &c_arg_vec[0], "Halide compiler\n"); } diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 255b2a67a339..54199485fed5 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -223,7 +223,7 @@ class Deinterleaver : public IRGraphMutator { } } if (op->value.type().lanes() > 1) { - // There is probably a more efficient way to this. + // There is probably a more efficient way to do this. return mutate(flatten_nested_ramps(op)); } @@ -236,7 +236,14 @@ class Deinterleaver : public IRGraphMutator { } else { Type t = op->type.with_lanes(new_lanes); ModulusRemainder align = op->alignment; - // TODO: Figure out the alignment of every nth lane + // The alignment of a Load refers to the alignment of the first + // lane, so we can preserve the existing alignment metadata if the + // deinterleave is asking for any subset of lanes that includes the + // first. Otherwise we just drop it. We could check if the index is + // a ramp with constant stride or some other special case, but if + // that's the case, the simplifier is very good at figuring out the + // alignment, and it has access to context (e.g. the alignment of + // enclosing lets) that we do not have here. if (starting_lane != 0) { align = ModulusRemainder(); } diff --git a/src/IRMatch.h b/src/IRMatch.h index afc264582651..676bd037bd27 100644 --- a/src/IRMatch.h +++ b/src/IRMatch.h @@ -1320,11 +1320,6 @@ constexpr bool and_reduce(bool first, Args... rest) { return first && and_reduce(rest...); } -// TODO: this can be replaced with std::min() once we require C++14 or later -constexpr int const_min(int a, int b) { - return a < b ? a : b; -} - template struct OptionalIntrinType { bool check(const Type &) const { @@ -1413,7 +1408,7 @@ struct Intrin { return saturating_cast(optional_type_hint.type, std::move(arg0)); } - Expr arg1 = std::get(args).make(state, type_hint); + Expr arg1 = std::get(1, sizeof...(Args) - 1)>(args).make(state, type_hint); if (intrin == Call::absd) { return absd(std::move(arg0), std::move(arg1)); } else if (intrin == Call::widen_right_add) { @@ -1448,7 +1443,7 @@ struct Intrin { return rounding_shift_right(std::move(arg0), std::move(arg1)); } - Expr arg2 = std::get(args).make(state, type_hint); + Expr arg2 = std::get(2, sizeof...(Args) - 1)>(args).make(state, type_hint); if (intrin == Call::mul_shift_right) { return mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2)); } else if (intrin == Call::rounding_mul_shift_right) { diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 41ea946e10f4..26869c4d8b15 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -1048,49 +1048,41 @@ Expr strided_ramp_base(const Expr &e, int stride) { namespace { -struct RemoveLikelies : public IRMutator { +// Replace a specified list of intrinsics with their first arg. +class RemoveIntrinsics : public IRMutator { using IRMutator::visit; + const std::initializer_list &ops; + Expr visit(const Call *op) override { - if (op->is_intrinsic(Call::likely) || - op->is_intrinsic(Call::likely_if_innermost)) { + if (op->is_intrinsic(ops)) { return mutate(op->args[0]); } else { return IRMutator::visit(op); } } -}; -// TODO: There could just be one IRMutator that can remove -// calls from a list. If we need more of these, it might be worth -// doing that refactor. -struct RemovePromises : public IRMutator { - using IRMutator::visit; - Expr visit(const Call *op) override { - if (op->is_intrinsic(Call::promise_clamped) || - op->is_intrinsic(Call::unsafe_promise_clamped)) { - return mutate(op->args[0]); - } else { - return IRMutator::visit(op); - } +public: + RemoveIntrinsics(const std::initializer_list &ops) + : ops(ops) { } }; } // namespace Expr remove_likelies(const Expr &e) { - return RemoveLikelies().mutate(e); + return RemoveIntrinsics({Call::likely, Call::likely_if_innermost}).mutate(e); } Stmt remove_likelies(const Stmt &s) { - return RemoveLikelies().mutate(s); + return RemoveIntrinsics({Call::likely, Call::likely_if_innermost}).mutate(s); } Expr remove_promises(const Expr &e) { - return RemovePromises().mutate(e); + return RemoveIntrinsics({Call::promise_clamped, Call::unsafe_promise_clamped}).mutate(e); } Stmt remove_promises(const Stmt &s) { - return RemovePromises().mutate(s); + return RemoveIntrinsics({Call::promise_clamped, Call::unsafe_promise_clamped}).mutate(s); } Expr unwrap_tags(const Expr &e) { diff --git a/src/Var.h b/src/Var.h index 33207d3cea0d..fbb976b476d9 100644 --- a/src/Var.h +++ b/src/Var.h @@ -25,7 +25,9 @@ class Var { Expr e; public: - /** Construct a Var with the given name */ + /** Construct a Var with the given name. Unlike Funcs, this will be treated + * as the same Var as another other Var with the same name, including + * implicit Vars. */ Var(const std::string &n); /** Construct a Var with an automatically-generated unique name. */ @@ -120,9 +122,6 @@ class Var { static Var implicit(int n); /** Return whether a variable name is of the form for an implicit argument. - * TODO: This is almost guaranteed to incorrectly fire on user - * declared variables at some point. We should likely prevent - * user Var declarations from making names of this form. */ //{ static bool is_implicit(const std::string &name); From 526364f301e668da3b5d2fa8f59b412343d770c0 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 20 Dec 2024 11:13:24 -0800 Subject: [PATCH 07/16] Use lossless_cast for saturating casts from unsigned to signed on x86 (#8527) * Use lossless_cast to handle saturating casts from unsigned to signed in x86 This takes care of a TODO. I also moved that block of code into an if statement that only considers saturating casts, to avoid wasting time on it for all other call nodes. * Fix incorrect vector width in test --- src/CodeGen_X86.cpp | 81 ++++++++++++-------------- test/correctness/simd_op_check_x86.cpp | 6 ++ 2 files changed, 44 insertions(+), 43 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 90609e1477c6..3d3f071b6b0f 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -647,7 +647,6 @@ void CodeGen_X86::visit(const Call *op) { Expr pattern; }; - // clang-format off static const Pattern patterns[] = { {"pmulh", mul_shift_right(wild_i16x_, wild_i16x_, 16)}, {"pmulh", mul_shift_right(wild_u16x_, wild_u16x_, 16)}, @@ -656,7 +655,6 @@ void CodeGen_X86::visit(const Call *op) { {"saturating_narrow", i8_sat(wild_i16x_)}, {"saturating_narrow", u8_sat(wild_i16x_)}, }; - // clang-format on vector matches; for (const auto &pattern : patterns) { @@ -668,52 +666,49 @@ void CodeGen_X86::visit(const Call *op) { } } - // clang-format off - static const Pattern reinterpret_patterns[] = { - {"saturating_narrow", i16_sat(wild_u32x_)}, - {"saturating_narrow", u16_sat(wild_u32x_)}, - {"saturating_narrow", i8_sat(wild_u16x_)}, - {"saturating_narrow", u8_sat(wild_u16x_)}, - }; - // clang-format on + if (op->is_intrinsic(Call::saturating_cast)) { - // Search for saturating casts where the inner value can be - // reinterpreted to signed, so that we can use existing - // saturating_narrow instructions. - // TODO: should use lossless_cast once it is fixed. - for (const auto &pattern : reinterpret_patterns) { - if (expr_match(pattern.pattern, op, matches)) { - const Expr &expr = matches[0]; - const Type &t = expr.type(); - // TODO(8212): might want to keep track of scope of bounds information. - const ConstantInterval ibounds = constant_integer_bounds(expr); - const Type reint_type = t.with_code(halide_type_int); - // If the signed type can represent the maximum value unsigned value, - // we can safely reinterpret this unsigned expression as signed. - if (reint_type.can_represent(ibounds)) { - // Can safely reinterpret to signed integer. - matches[0] = cast(reint_type, matches[0]); - value = call_overloaded_intrin(op->type, pattern.intrin, matches); - if (value) { - return; + static const Pattern reinterpret_patterns[] = { + {"saturating_narrow", i16_sat(wild_u32x_)}, + {"saturating_narrow", u16_sat(wild_u32x_)}, + {"saturating_narrow", i8_sat(wild_u16x_)}, + {"saturating_narrow", u8_sat(wild_u16x_)}, + }; + + // Search for saturating casts where the inner value can be + // reinterpreted to signed, so that we can use existing + // saturating_narrow instructions. + for (const auto &pattern : reinterpret_patterns) { + if (expr_match(pattern.pattern, op, matches)) { + const Type signed_type = matches[0].type().with_code(halide_type_int); + Expr e = lossless_cast(signed_type, matches[0]); + if (e.defined()) { + // Can safely reinterpret to signed integer. + matches[0] = e; + value = call_overloaded_intrin(op->type, pattern.intrin, matches); + if (value) { + return; + } } + // No reinterpret patterns match the same input, so stop matching. + break; } - // No reinterpret patterns match the same input, so stop matching. - break; } - } - static const vector> cast_rewrites = { - // Some double-narrowing saturating casts can be better expressed as - // combinations of single-narrowing saturating casts. - {u8_sat(wild_i32x_), u8_sat(i16_sat(wild_i32x_))}, - {i8_sat(wild_i32x_), i8_sat(i16_sat(wild_i32x_))}, - }; - for (const auto &i : cast_rewrites) { - if (expr_match(i.first, op, matches)) { - Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes())); - value = codegen(replacement); - return; + static const vector> cast_rewrites = { + // Some double-narrowing saturating casts can be better expressed as + // combinations of single-narrowing saturating casts. + {u8_sat(wild_i32x_), u8_sat(i16_sat(wild_i32x_))}, + {i8_sat(wild_i32x_), i8_sat(i16_sat(wild_i32x_))}, + {i8_sat(wild_u32x_), i8_sat(i16_sat(wild_u32x_))}, + }; + + for (const auto &i : cast_rewrites) { + if (expr_match(i.first, op, matches)) { + Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes())); + value = codegen(replacement); + return; + } } } diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp index 4a81dfbdf926..e8f61544fe7c 100644 --- a/test/correctness/simd_op_check_x86.cpp +++ b/test/correctness/simd_op_check_x86.cpp @@ -234,6 +234,12 @@ class SimdOpCheckX86 : public SimdOpCheckTest { check(std::string("packssdw") + check_suffix, 8 * w, u8_sat(i32_1)); check(std::string("packssdw") + check_suffix, 8 * w, i8_sat(i32_1)); + // A uint without the top bit set can be reinterpreted as an int + // so that packssdw can be used. + check(std::string("packssdw") + check_suffix, 4 * w, i16_sat(u32_1 >> 1)); + check(std::string("packssdw") + check_suffix, 8 * w, i8_sat(u32_1 >> 1)); + check(std::string("packsswb") + check_suffix, 8 * w, i8_sat(u16_1 >> 1)); + // Sum-of-absolute-difference ops { const int f = 8; // reduction factor. From c11d977a3f2c32eaa276b3d85243e4bcb7c23c3c Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 20 Dec 2024 11:14:02 -0800 Subject: [PATCH 08/16] Don't cache mutations of Exprs that have only one reference to them (#8518) Don't cache mutations of Exprs that have only one reference to them. This speeds up lowering of local laplacian by about 5% --- src/IRMutator.cpp | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/IRMutator.cpp b/src/IRMutator.cpp index 7e3e278bc75b..f0b861f1d5c0 100644 --- a/src/IRMutator.cpp +++ b/src/IRMutator.cpp @@ -381,21 +381,33 @@ Stmt IRMutator::visit(const HoistedStorage *op) { } Stmt IRGraphMutator::mutate(const Stmt &s) { - auto p = stmt_replacements.emplace(s, Stmt()); - if (p.second) { - // N.B: Inserting into a map (as the recursive mutate call - // does), does not invalidate existing iterators. - p.first->second = IRMutator::mutate(s); + if (s.is_sole_reference()) { + // There's no point in caching mutations of this Stmt. We can never + // possibly see it again, and it can't be in the cache already if this + // is the sole reference. Doing this here and in the Expr mutate method + // below speeds up lowering by about 5% + return IRMutator::mutate(s); + } else { + auto p = stmt_replacements.emplace(s, Stmt()); + if (p.second) { + // N.B: Inserting into a map (as the recursive mutate call + // does), does not invalidate existing iterators. + p.first->second = IRMutator::mutate(s); + } + return p.first->second; } - return p.first->second; } Expr IRGraphMutator::mutate(const Expr &e) { - auto p = expr_replacements.emplace(e, Expr()); - if (p.second) { - p.first->second = IRMutator::mutate(e); + if (e.is_sole_reference()) { + return IRMutator::mutate(e); + } else { + auto p = expr_replacements.emplace(e, Expr()); + if (p.second) { + p.first->second = IRMutator::mutate(e); + } + return p.first->second; } - return p.first->second; } std::pair, bool> IRMutator::mutate_with_changes(const std::vector &old_exprs) { From 9c3e6de627eb1c0c74a21b61485f35e546bea86e Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 23 Dec 2024 12:48:06 -0800 Subject: [PATCH 09/16] Fix #8534 (#8535) * Fix #8534 This needs to take src by value, because the helper function mutates ones of its members in-place. * Add comment --- src/runtime/HalideBuffer.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/runtime/HalideBuffer.h b/src/runtime/HalideBuffer.h index f854cf43c24f..9741c0278e3d 100644 --- a/src/runtime/HalideBuffer.h +++ b/src/runtime/HalideBuffer.h @@ -1991,9 +1991,11 @@ class Buffer { /** Make a buffer with the same shape and memory nesting order as * another buffer. It may have a different type. */ template - static Buffer make_with_shape_of(const Buffer &src, + static Buffer make_with_shape_of(Buffer src, void *(*allocate_fn)(size_t) = nullptr, void (*deallocate_fn)(void *) = nullptr) { + // Note that src is taken by value because its dims are mutated + // in-place by the helper. Do not change to taking it by reference. static_assert(Dims == D2 || Dims == AnyDims); const halide_type_t dst_type = T_is_void ? src.type() : halide_type_of::type>(); return Buffer<>::make_with_shape_of_helper(dst_type, src.dimensions(), src.buf.dim, From b5a5ca2768389ee94f448bf8f48f140df2e0bdeb Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 23 Dec 2024 12:49:24 -0800 Subject: [PATCH 10/16] Remove llvm version check from Makefile (#8533) It's redundant with the one in LLVM_Headers.h, so all it does is break every time we get a new LLVM version. --- Makefile | 169 ++++++++----------------------------------------------- 1 file changed, 23 insertions(+), 146 deletions(-) diff --git a/Makefile b/Makefile index 4d4cde71331e..5d242a2e4aa5 100644 --- a/Makefile +++ b/Makefile @@ -1092,83 +1092,83 @@ RUNTIME_CXX_FLAGS = \ -Wno-sync-alignment \ -isystem $(ROOT_DIR)/dependencies/vulkan/include -$(BUILD_DIR)/initmod.windows_%_x86_32.ll: $(SRC_DIR)/runtime/windows_%_x86.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_x86_32.ll: $(SRC_DIR)/runtime/windows_%_x86.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_WIN_X86_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*_x86.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_x86_32.d -$(BUILD_DIR)/initmod.windows_%_x86_64.ll: $(SRC_DIR)/runtime/windows_%_x86.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_x86_64.ll: $(SRC_DIR)/runtime/windows_%_x86.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m64 -target $(RUNTIME_TRIPLE_WIN_X86_64) -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*_x86.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_x86_64.d -$(BUILD_DIR)/initmod.windows_%_arm_32.ll: $(SRC_DIR)/runtime/windows_%_arm.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_arm_32.ll: $(SRC_DIR)/runtime/windows_%_arm.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_WIN_ARM_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*_arm.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_arm_32.d -$(BUILD_DIR)/initmod.windows_%_arm_64.ll: $(SRC_DIR)/runtime/windows_%_arm.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_arm_64.ll: $(SRC_DIR)/runtime/windows_%_arm.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m64 -target $(RUNTIME_TRIPLE_WIN_ARM_64) -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*_arm.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_arm_64.d -$(BUILD_DIR)/initmod.windows_%_32.ll: $(SRC_DIR)/runtime/windows_%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_32.ll: $(SRC_DIR)/runtime/windows_%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_WIN_X86_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_32.d -$(BUILD_DIR)/initmod.windows_%_64.ll: $(SRC_DIR)/runtime/windows_%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_64.ll: $(SRC_DIR)/runtime/windows_%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m64 -target $(RUNTIME_TRIPLE_WIN_GENERIC_64) -fshort-wchar -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_64.d -$(BUILD_DIR)/initmod.webgpu_%_32.ll: $(SRC_DIR)/runtime/webgpu_%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.webgpu_%_32.ll: $(SRC_DIR)/runtime/webgpu_%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_WEBGPU_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/webgpu_$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.webgpu_$*_32.d -$(BUILD_DIR)/initmod.webgpu_%_64.ll: $(SRC_DIR)/runtime/webgpu_%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.webgpu_%_64.ll: $(SRC_DIR)/runtime/webgpu_%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m64 -target $(RUNTIME_TRIPLE_WEBGPU_64) -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/webgpu_$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.webgpu_$*_64.d -$(BUILD_DIR)/initmod.webgpu_%_32_debug.ll: $(SRC_DIR)/runtime/webgpu_%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.webgpu_%_32_debug.ll: $(SRC_DIR)/runtime/webgpu_%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) -g -DDEBUG_RUNTIME $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_WEBGPU_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/webgpu_$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.webgpu_$*_32_debug.d -$(BUILD_DIR)/initmod.webgpu_%_64_debug.ll: $(SRC_DIR)/runtime/webgpu_%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.webgpu_%_64_debug.ll: $(SRC_DIR)/runtime/webgpu_%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) -g -DDEBUG_RUNTIME $(RUNTIME_CXX_FLAGS) -m64 -target $(RUNTIME_TRIPLE_WEBGPU_64) -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/webgpu_$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.webgpu_$*_64_debug.d -$(BUILD_DIR)/initmod.%_64.ll: $(SRC_DIR)/runtime/%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.%_64.ll: $(SRC_DIR)/runtime/%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -fpic -m64 -target $(RUNTIME_TRIPLE_64) -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.$*_64.d -$(BUILD_DIR)/initmod.%_32.ll: $(SRC_DIR)/runtime/%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.%_32.ll: $(SRC_DIR)/runtime/%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -fpic -m32 -target $(RUNTIME_TRIPLE_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.$*_32.d -$(BUILD_DIR)/initmod.windows_%_x86_32_debug.ll: $(SRC_DIR)/runtime/windows_%_x86.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_x86_32_debug.ll: $(SRC_DIR)/runtime/windows_%_x86.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) -g -DDEBUG_RUNTIME $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_WIN_X86_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*_x86.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_x86_32_debug.d -$(BUILD_DIR)/initmod.windows_%_x86_64_debug.ll: $(SRC_DIR)/runtime/windows_%_x86.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_x86_64_debug.ll: $(SRC_DIR)/runtime/windows_%_x86.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) -g -DDEBUG_RUNTIME $(RUNTIME_CXX_FLAGS) -m64 -target $(RUNTIME_TRIPLE_WIN_X86_64) -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*_x86.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_x86_64_debug.d -$(BUILD_DIR)/initmod.windows_%_arm_32_debug.ll: $(SRC_DIR)/runtime/windows_%_arm.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_arm_32_debug.ll: $(SRC_DIR)/runtime/windows_%_arm.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) -g -DDEBUG_RUNTIME $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_WIN_ARM_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*_arm.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_arm_32_debug.d -$(BUILD_DIR)/initmod.windows_%_arm_64_debug.ll: $(SRC_DIR)/runtime/windows_%_arm.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_arm_64_debug.ll: $(SRC_DIR)/runtime/windows_%_arm.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) -g -DDEBUG_RUNTIME $(RUNTIME_CXX_FLAGS) -m64 -target $(RUNTIME_TRIPLE_WIN_ARM_64) -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*_arm.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_arm_64_debug.d -$(BUILD_DIR)/initmod.windows_%_64_debug.ll: $(SRC_DIR)/runtime/windows_%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_64_debug.ll: $(SRC_DIR)/runtime/windows_%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) -g -DDEBUG_RUNTIME $(RUNTIME_CXX_FLAGS) -m64 -target $(RUNTIME_TRIPLE_WIN_GENERIC_64) -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_64_debug.d -$(BUILD_DIR)/initmod.%_64_debug.ll: $(SRC_DIR)/runtime/%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.%_64_debug.ll: $(SRC_DIR)/runtime/%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) -g -DDEBUG_RUNTIME $(RUNTIME_CXX_FLAGS) -fpic -m64 -target $(RUNTIME_TRIPLE_64) -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.$*_64_debug.d -$(BUILD_DIR)/initmod.windows_%_32_debug.ll: $(SRC_DIR)/runtime/windows_%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.windows_%_32_debug.ll: $(SRC_DIR)/runtime/windows_%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) -g -DDEBUG_RUNTIME $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_WIN_X86_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_32_debug.d -$(BUILD_DIR)/initmod.%_32_debug.ll: $(SRC_DIR)/runtime/%.cpp $(BUILD_DIR)/clang_ok +$(BUILD_DIR)/initmod.%_32_debug.ll: $(SRC_DIR)/runtime/%.cpp @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) -g -DDEBUG_RUNTIME -O3 $(RUNTIME_CXX_FLAGS) -fpic -m32 -target $(RUNTIME_TRIPLE_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.$*_32_debug.d @@ -1176,7 +1176,7 @@ $(BUILD_DIR)/initmod.%_ll.ll: $(SRC_DIR)/runtime/%.ll @mkdir -p $(@D) cp $(SRC_DIR)/runtime/$*.ll $(BUILD_DIR)/initmod.$*_ll.ll -$(BUILD_DIR)/initmod.%.bc: $(BUILD_DIR)/initmod.%.ll $(BUILD_DIR)/llvm_ok +$(BUILD_DIR)/initmod.%.bc: $(BUILD_DIR)/initmod.%.ll $(LLVM_AS) $(BUILD_DIR)/initmod.$*.ll -o $(BUILD_DIR)/initmod.$*.bc $(BUILD_DIR)/initmod.%.cpp: $(BIN_DIR)/binary2cpp $(BUILD_DIR)/initmod.%.bc @@ -1218,11 +1218,11 @@ $(BUILD_DIR)/c_template.%.o: $(BUILD_DIR)/c_template.%.cpp $(BUILD_DIR)/html_template.%.o: $(BUILD_DIR)/html_template.%.cpp $(CXX) -c $< -o $@ -MMD -MP -MF $(BUILD_DIR)/$*.d -MT $(BUILD_DIR)/$*.o -$(BUILD_DIR)/%.o: $(SRC_DIR)/%.cpp $(BUILD_DIR)/llvm_ok +$(BUILD_DIR)/%.o: $(SRC_DIR)/%.cpp @mkdir -p $(@D) $(CXX) $(CXX_FLAGS) -c $< -o $@ -MMD -MP -MF $(BUILD_DIR)/$*.d -MT $(BUILD_DIR)/$*.o -$(BUILD_DIR)/Simplify_%.o: $(SRC_DIR)/Simplify_%.cpp $(SRC_DIR)/Simplify_Internal.h $(BUILD_DIR)/llvm_ok +$(BUILD_DIR)/Simplify_%.o: $(SRC_DIR)/Simplify_%.cpp $(SRC_DIR)/Simplify_Internal.h @mkdir -p $(@D) $(CXX) $(CXX_FLAGS) -c $< -o $@ -MMD -MP -MF $(BUILD_DIR)/Simplify_$*.d -MT $@ @@ -2212,129 +2212,6 @@ benchmark_apps: $(BENCHMARK_APPS) || exit 1 ; \ done -# It's just for compiling the runtime, so earlier clangs *might* work, -# but best to peg it to the minimum llvm version. -ifneq (,$(findstring clang version 3.7,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 3.8,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 4.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 5.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 6.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 7.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 7.1,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 8.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 9.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 10.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 11.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 11.1,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 12.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 13.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 14.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 15.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 16.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 17.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 18.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 19.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring clang version 20.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq (,$(findstring Apple LLVM version 5.0,$(CLANG_VERSION))) -CLANG_OK=yes -endif - -ifneq ($(CLANG_OK), ) -$(BUILD_DIR)/clang_ok: - @echo "Found a new enough version of clang" - mkdir -p $(BUILD_DIR) - touch $(BUILD_DIR)/clang_ok -else -$(BUILD_DIR)/clang_ok: - @echo "Can't find clang or version of clang too old (we need 3.7 or greater):" - @echo "You can override this check by setting CLANG_OK=y" - echo '$(CLANG_VERSION)' - echo $(findstring version 3,$(CLANG_VERSION)) - echo $(findstring version 3.0,$(CLANG_VERSION)) - $(CLANG) --version - @exit 1 -endif - -ifneq (,$(findstring $(LLVM_VERSION_TIMES_10), 160 170 180 190 200)) -LLVM_OK=yes -endif - -ifneq ($(LLVM_OK), ) -$(BUILD_DIR)/llvm_ok: $(BUILD_DIR)/rtti_ok - @echo "Found a new enough version of llvm" - mkdir -p $(BUILD_DIR) - touch $(BUILD_DIR)/llvm_ok -else -$(BUILD_DIR)/llvm_ok: - @echo "Can't find llvm or version of llvm too old (we need 9.0 or greater):" - @echo "You can override this check by setting LLVM_OK=y" - $(LLVM_CONFIG) --version - @exit 1 -endif - ifneq ($(WITH_RTTI), ) ifneq ($(LLVM_HAS_NO_RTTI), ) else From efb77e5dd92061e99362ee9b461f2a55bc81c777 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 27 Dec 2024 10:23:08 -0800 Subject: [PATCH 11/16] Skip fast exp/log/pow/sin/cosine tests without sse 4.1 (#8541) Fixes #8536 --- src/IROperator.h | 14 +++++++++----- test/performance/fast_pow.cpp | 6 ++++++ test/performance/fast_sine_cosine.cpp | 7 +++++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/IROperator.h b/src/IROperator.h index 0db5606f011c..2b0ce6d97563 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -970,8 +970,9 @@ Expr pow(Expr x, Expr y); * mantissa. Vectorizes cleanly. */ Expr erf(const Expr &x); -/** Fast vectorizable approximation to some trigonometric functions for Float(32). - * Absolute approximation error is less than 1e-5. */ +/** Fast vectorizable approximation to some trigonometric functions for + * Float(32). Absolute approximation error is less than 1e-5. Slow on x86 if + * you don't have at least sse 4.1. */ // @{ Expr fast_sin(const Expr &x); Expr fast_cos(const Expr &x); @@ -979,19 +980,22 @@ Expr fast_cos(const Expr &x); /** Fast approximate cleanly vectorizable log for Float(32). Returns * nonsense for x <= 0.0f. Accurate up to the last 5 bits of the - * mantissa. Vectorizes cleanly. */ + * mantissa. Vectorizes cleanly. Slow on x86 if you don't + * have at least sse 4.1. */ Expr fast_log(const Expr &x); /** Fast approximate cleanly vectorizable exp for Float(32). Returns * nonsense for inputs that would overflow or underflow. Typically * accurate up to the last 5 bits of the mantissa. Gets worse when - * approaching overflow. Vectorizes cleanly. */ + * approaching overflow. Vectorizes cleanly. Slow on x86 if you don't + * have at least sse 4.1. */ Expr fast_exp(const Expr &x); /** Fast approximate cleanly vectorizable pow for Float(32). Returns * nonsense for x < 0.0f. Accurate up to the last 5 bits of the * mantissa for typical exponents. Gets worse when approaching - * overflow. Vectorizes cleanly. */ + * overflow. Vectorizes cleanly. Slow on x86 if you don't + * have at least sse 4.1. */ Expr fast_pow(Expr x, Expr y); /** Fast approximate inverse for Float(32). Corresponds to the rcpps diff --git a/test/performance/fast_pow.cpp b/test/performance/fast_pow.cpp index 801d5f3133f2..24cea2c32418 100644 --- a/test/performance/fast_pow.cpp +++ b/test/performance/fast_pow.cpp @@ -20,6 +20,12 @@ int main(int argc, char **argv) { printf("HL_TARGET is: %s\n", hl_target.to_string().c_str()); printf("HL_JIT_TARGET is: %s\n", hl_jit_target.to_string().c_str()); + if (hl_jit_target.arch == Target::X86 && + !hl_jit_target.has_feature(Target::SSE41)) { + printf("[SKIP] These intrinsics are known to be slow on x86 without sse 4.1.\n"); + return 0; + } + if (hl_jit_target.arch == Target::WebAssembly) { printf("[SKIP] Performance tests are meaningless and/or misleading under WebAssembly interpreter.\n"); return 0; diff --git a/test/performance/fast_sine_cosine.cpp b/test/performance/fast_sine_cosine.cpp index dc8e7a360550..81f79f337c32 100644 --- a/test/performance/fast_sine_cosine.cpp +++ b/test/performance/fast_sine_cosine.cpp @@ -10,6 +10,13 @@ using namespace Halide::Tools; int main(int argc, char **argv) { Target target = get_jit_target_from_environment(); + + if (target.arch == Target::X86 && + !target.has_feature(Target::SSE41)) { + printf("[SKIP] These intrinsics are known to be slow on x86 without sse 4.1.\n"); + return 0; + } + if (target.arch == Target::WebAssembly) { printf("[SKIP] Performance tests are meaningless and/or misleading under WebAssembly interpreter.\n"); return 0; From c2d5ea309080125352a184878005286a1375ba06 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 27 Dec 2024 10:23:18 -0800 Subject: [PATCH 12/16] Hopefully fix flaky mullapudi reorder test (#8542) * Hopefully fix flaky test * clang-format --- test/autoschedulers/mullapudi2016/reorder.cpp | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/test/autoschedulers/mullapudi2016/reorder.cpp b/test/autoschedulers/mullapudi2016/reorder.cpp index 37476a0ff0dc..942da121b8b1 100644 --- a/test/autoschedulers/mullapudi2016/reorder.cpp +++ b/test/autoschedulers/mullapudi2016/reorder.cpp @@ -7,10 +7,16 @@ using namespace Halide::Tools; double run_test_1(bool auto_schedule) { Var x("x"), y("y"), dx("dx"), dy("dy"), c("c"); + int W = 1024; + int H = 1920; + int search_area = 7; + + Buffer im(2048); + im.fill(17); + Func f("f"); - f(x, y, dx, dy) = x + y + dx + dy; + f(x, y, dx, dy) = im(x) + im(y + 1) + im(dx + search_area / 2) + im(dy + search_area / 2); - int search_area = 7; RDom dom(-search_area / 2, search_area, -search_area / 2, search_area, "dom"); // If 'f' is inlined into 'r', the only storage layout that the auto scheduler @@ -23,23 +29,20 @@ double run_test_1(bool auto_schedule) { if (auto_schedule) { // Provide estimates on the pipeline output - r.set_estimates({{0, 1024}, {0, 1024}, {0, 3}}); + r.set_estimates({{0, W}, {0, H}, {0, 3}}); // Auto-schedule the pipeline p.apply_autoscheduler(target, {"Mullapudi2016"}); } else { - /* + Var par; r.update(0).fuse(c, y, par).parallel(par).reorder(x, dom.x, dom.y).vectorize(x, 4); - r.fuse(c, y, par).parallel(par).vectorize(x, 4); */ - - // The sequential schedule in this case seems to perform best which is - // odd have to investigate this further. + r.fuse(c, y, par).parallel(par).vectorize(x, 4); } // Inspect the schedule (only for debugging)) // r.print_loop_nest(); // Run the schedule - Buffer out(1024, 1024, 3); + Buffer out(W, H, 3); double t = benchmark(3, 10, [&]() { p.realize(out); }); @@ -154,7 +157,7 @@ int main(int argc, char **argv) { double manual_time = run_test_1(false); double auto_time = run_test_1(true); - const double slowdown_factor = 15.0; // TODO: whoa + const double slowdown_factor = 2.0; if (!get_jit_target_from_environment().has_gpu_feature() && auto_time > manual_time * slowdown_factor) { std::cerr << "Autoscheduler time (1) is slower than expected:\n" << "======================\n" From 097aee94e235afa70910d212717cd94eb6faa241 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 27 Dec 2024 10:23:40 -0800 Subject: [PATCH 13/16] Use a consistent idiom for visit_let (#8540) visit_let in the codebase uses a wide variety of template names, argument names, and ways of getting the body type. This just picks one and uses it consistently. No functional changes. --- src/BoundSmallAllocations.cpp | 18 ++++++------ src/ClampUnsafeAccesses.cpp | 18 ++++++------ src/Deinterleave.cpp | 28 +++++++++--------- src/EliminateBoolVectors.cpp | 10 +++---- src/FuseGPUThreadLoops.cpp | 10 +++---- src/HexagonOptimize.cpp | 42 +++++++++++++-------------- src/LICM.cpp | 36 +++++++++++------------ src/LowerWarpShuffles.cpp | 10 +++---- src/OptimizeShuffles.cpp | 12 ++++---- src/RemoveUndef.cpp | 18 ++++++------ src/SimplifyCorrelatedDifferences.cpp | 10 +++---- src/TrimNoOps.cpp | 10 +++---- src/UniquifyVariableNames.cpp | 6 ++-- src/VectorizeLoops.cpp | 12 ++++---- 14 files changed, 120 insertions(+), 120 deletions(-) diff --git a/src/BoundSmallAllocations.cpp b/src/BoundSmallAllocations.cpp index 849d224051b4..80f58889448c 100644 --- a/src/BoundSmallAllocations.cpp +++ b/src/BoundSmallAllocations.cpp @@ -17,40 +17,40 @@ class BoundSmallAllocations : public IRMutator { // Track constant bounds Scope scope; - template - Body visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // Visit an entire chain of lets in a single method to conserve stack space. struct Frame { - const T *op; + const LetOrLetStmt *op; ScopedBinding binding; - Frame(const T *op, Scope &scope) + Frame(const LetOrLetStmt *op, Scope &scope) : op(op), binding(scope, op->name, find_constant_bounds(op->value, scope)) { } }; std::vector frames; - Body result; + decltype(op->body) result; do { result = op->body; frames.emplace_back(op, scope); - } while ((op = result.template as())); + } while ((op = result.template as())); result = mutate(result); for (const auto &frame : reverse_view(frames)) { - result = T::make(frame.op->name, frame.op->value, result); + result = LetOrLetStmt::make(frame.op->name, frame.op->value, result); } return result; } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } bool in_thread_loop = false; diff --git a/src/ClampUnsafeAccesses.cpp b/src/ClampUnsafeAccesses.cpp index b3dd9ddc235e..ed6955446196 100644 --- a/src/ClampUnsafeAccesses.cpp +++ b/src/ClampUnsafeAccesses.cpp @@ -42,11 +42,11 @@ struct ClampUnsafeAccesses : IRMutator { } Expr visit(const Let *let) override { - return visit_let(let); + return visit_let(let); } Stmt visit(const LetStmt *let) override { - return visit_let(let); + return visit_let(let); } Expr visit(const Variable *var) override { @@ -80,15 +80,15 @@ struct ClampUnsafeAccesses : IRMutator { } private: - template - Body visit_let(const L *let) { - ScopedBinding binding(let_var_inside_indexing, let->name, false); - Body body = mutate(let->body); + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { + ScopedBinding binding(let_var_inside_indexing, op->name, false); + auto body = mutate(op->body); - ScopedValue s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(let->name)); - Expr value = mutate(let->value); + ScopedValue s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(op->name)); + Expr value = mutate(op->value); - return L::make(let->name, std::move(value), std::move(body)); + return LetOrLetStmt::make(op->name, std::move(value), std::move(body)); } bool bounds_smaller_than_type(const Interval &bounds, Type type) { diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 54199485fed5..6cd4f490f01c 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -465,44 +465,44 @@ class Interleaver : public IRMutator { return Shuffle::make_interleave(exprs); } - template - Body visit_lets(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // Visit an entire chain of lets in a single method to conserve stack space. struct Frame { - const T *op; + const LetOrLetStmt *op; Expr new_value; ScopedBinding<> binding; - Frame(const T *op, Expr v, Scope &scope) + Frame(const LetOrLetStmt *op, Expr v, Scope &scope) : op(op), new_value(std::move(v)), binding(new_value.type().is_vector(), scope, op->name) { } }; std::vector frames; - Body result; + decltype(op->body) result; do { result = op->body; frames.emplace_back(op, mutate(op->value), vector_lets); - } while ((op = result.template as())); + } while ((op = result.template as())); result = mutate(result); for (const auto &frame : reverse_view(frames)) { Expr value = std::move(frame.new_value); - result = T::make(frame.op->name, value, result); + result = LetOrLetStmt::make(frame.op->name, value, result); // For vector lets, we may additionally need a let defining the even and odd lanes only if (value.type().is_vector()) { if (value.type().lanes() % 2 == 0) { - result = T::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result); - result = T::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result); + result = LetOrLetStmt::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result); + result = LetOrLetStmt::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result); } if (value.type().lanes() % 3 == 0) { - result = T::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result); - result = T::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result); - result = T::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result); + result = LetOrLetStmt::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result); + result = LetOrLetStmt::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result); + result = LetOrLetStmt::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result); } } } @@ -511,11 +511,11 @@ class Interleaver : public IRMutator { } Expr visit(const Let *op) override { - return visit_lets(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_lets(op); + return visit_let(op); } Expr visit(const Ramp *op) override { diff --git a/src/EliminateBoolVectors.cpp b/src/EliminateBoolVectors.cpp index 2517ecf70c1c..e4afa8f21569 100644 --- a/src/EliminateBoolVectors.cpp +++ b/src/EliminateBoolVectors.cpp @@ -287,8 +287,8 @@ class EliminateBoolVectors : public IRMutator { return expr; } - template - NodeType visit_let(const LetType *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { Expr value = mutate(op->value); // We changed the type of the let, we need to replace the @@ -305,17 +305,17 @@ class EliminateBoolVectors : public IRMutator { } if (!value.same_as(op->value) || !body.same_as(op->body)) { - return LetType::make(op->name, value, body); + return LetOrLetStmt::make(op->name, value, body); } else { return op; } } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } }; diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index 3464a91ddb58..b7da1555526a 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -1157,9 +1157,9 @@ class ExtractRegisterAllocations : public IRMutator { op->param, mutate(op->predicate), op->alignment); } - template - ExprOrStmt visit_let(const LetOrLetStmt *op) { - ExprOrStmt body = op->body; + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { + auto body = op->body; body = mutate(op->body); Expr value = mutate(op->value); @@ -1178,11 +1178,11 @@ class ExtractRegisterAllocations : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Scope register_allocations; diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index d75231813215..a88f0e6bad57 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -1088,20 +1088,20 @@ class OptimizePatterns : public IRMutator { } } - template - NodeType visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds)); - NodeType node = IRMutator::visit(op); + auto node = IRMutator::visit(op); bounds.pop(op->name); return node; } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Expr visit(const Div *op) override { @@ -1599,12 +1599,12 @@ class EliminateInterleaves : public IRMutator { } } - template - NodeType visit_let(const LetType *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { Expr value = mutate(op->value); string deinterleaved_name; - NodeType body; + decltype(op->body) body; // Other code in this mutator needs to be able to tell the // difference between a Let that yields a deinterleave, and a // let that has a removable deinterleave. Lets that can @@ -1632,10 +1632,10 @@ class EliminateInterleaves : public IRMutator { return op; } else if (body.same_as(op->body)) { // If the body didn't change, we must not have used the deinterleaved value. - return LetType::make(op->name, value, body); + return LetOrLetStmt::make(op->name, value, body); } else { // We need to rewrap the body with new lets. - NodeType result = body; + auto result = body; bool deinterleaved_used = stmt_or_expr_uses_var(result, deinterleaved_name); bool interleaved_used = stmt_or_expr_uses_var(result, op->name); if (deinterleaved_used && interleaved_used) { @@ -1653,14 +1653,14 @@ class EliminateInterleaves : public IRMutator { interleaved = native_interleave(interleaved); } - result = LetType::make(op->name, interleaved, result); - return LetType::make(deinterleaved_name, deinterleaved, result); + result = LetOrLetStmt::make(op->name, interleaved, result); + return LetOrLetStmt::make(deinterleaved_name, deinterleaved, result); } else if (deinterleaved_used) { // Only the deinterleaved value is used, we can eliminate the interleave. - return LetType::make(deinterleaved_name, remove_interleave(value), result); + return LetOrLetStmt::make(deinterleaved_name, remove_interleave(value), result); } else if (interleaved_used) { // Only the original value is used, regenerate the let. - return LetType::make(op->name, value, result); + return LetOrLetStmt::make(op->name, value, result); } else { // The let must have been dead. internal_assert(!stmt_or_expr_uses_var(op->body, op->name)) @@ -1671,7 +1671,7 @@ class EliminateInterleaves : public IRMutator { } Expr visit(const Let *op) override { - Expr expr = visit_let(op); + Expr expr = visit_let(op); // Lift interleaves out of Let expression bodies. const Let *let = expr.as(); @@ -1682,7 +1682,7 @@ class EliminateInterleaves : public IRMutator { } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Expr visit(const Cast *op) override { @@ -2047,13 +2047,13 @@ class ScatterGatherGenerator : public IRMutator { return IRMutator::visit(op); } - template - NodeType visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // We only care about vector lets. if (op->value.type().is_vector()) { bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds)); } - NodeType node = IRMutator::visit(op); + auto node = IRMutator::visit(op); if (op->value.type().is_vector()) { bounds.pop(op->name); } @@ -2061,11 +2061,11 @@ class ScatterGatherGenerator : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const Allocate *op) override { diff --git a/src/LICM.cpp b/src/LICM.cpp index 880354d89582..c73fcf5424ab 100644 --- a/src/LICM.cpp +++ b/src/LICM.cpp @@ -112,23 +112,23 @@ class LiftLoopInvariants : public IRMutator { return true; } - template - Body visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // Visit an entire chain of lets in a single method to conserve stack space. struct Frame { - const T *op; + const LetOrLetStmt *op; Expr new_value; ScopedBinding<> binding; - Frame(const T *op, Expr v, Scope<> &scope) + Frame(const LetOrLetStmt *op, Expr v, Scope<> &scope) : op(op), new_value(std::move(v)), binding(scope, op->name) { } }; vector frames; - Body result; + decltype(op->body) result; do { frames.emplace_back(op, mutate(op->value), varying); result = op->body; - } while ((op = result.template as())); + } while ((op = result.template as())); result = mutate(result); @@ -136,7 +136,7 @@ class LiftLoopInvariants : public IRMutator { if (frame.new_value.same_as(frame.op->value) && result.same_as(frame.op->body)) { result = frame.op; } else { - result = T::make(frame.op->name, std::move(frame.new_value), result); + result = LetOrLetStmt::make(frame.op->name, std::move(frame.new_value), result); } } @@ -144,11 +144,11 @@ class LiftLoopInvariants : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const For *op) override { @@ -476,20 +476,20 @@ class GroupLoopInvariants : public IRMutator { return stmt; } - template - Body visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { struct Frame { - const T *op; + const LetOrLetStmt *op; Expr new_value; ScopedBinding binding; - Frame(const T *op, Expr v, int depth, Scope &scope) + Frame(const LetOrLetStmt *op, Expr v, int depth, Scope &scope) : op(op), new_value(std::move(v)), binding(scope, op->name, depth) { } }; std::vector frames; - Body result; + decltype(op->body) result; do { result = op->body; @@ -498,7 +498,7 @@ class GroupLoopInvariants : public IRMutator { d = expr_depth(op->value); } frames.emplace_back(op, mutate(op->value), d, var_depth); - } while ((op = result.template as())); + } while ((op = result.template as())); result = mutate(result); @@ -506,7 +506,7 @@ class GroupLoopInvariants : public IRMutator { if (frame.new_value.same_as(frame.op->value) && result.same_as(frame.op->body)) { result = frame.op; } else { - result = T::make(frame.op->name, frame.new_value, result); + result = LetOrLetStmt::make(frame.op->name, frame.new_value, result); } } @@ -514,11 +514,11 @@ class GroupLoopInvariants : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } }; diff --git a/src/LowerWarpShuffles.cpp b/src/LowerWarpShuffles.cpp index fa3dfecb9a1f..2551fe0bffbb 100644 --- a/src/LowerWarpShuffles.cpp +++ b/src/LowerWarpShuffles.cpp @@ -691,10 +691,10 @@ class HoistWarpShufflesFromSingleIfStmt : public IRMutator { } } - template - ExprOrStmt visit_let(const LetOrLetStmt *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { Expr value = mutate(op->value); - ExprOrStmt body = mutate(op->body); + auto body = mutate(op->body); // If any of the lifted expressions use this, we also need to // lift this. @@ -712,11 +712,11 @@ class HoistWarpShufflesFromSingleIfStmt : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const For *op) override { diff --git a/src/OptimizeShuffles.cpp b/src/OptimizeShuffles.cpp index c986af62474e..0a88d02f0b60 100644 --- a/src/OptimizeShuffles.cpp +++ b/src/OptimizeShuffles.cpp @@ -38,13 +38,13 @@ class OptimizeShuffles : public IRMutator { return IRMutator::visit(op); } - template - NodeType visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // We only care about vector lets. if (op->value.type().is_vector()) { bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds)); } - NodeType node = IRMutator::visit(op); + auto node = IRMutator::visit(op); if (op->value.type().is_vector()) { bounds.pop(op->name); } @@ -53,12 +53,12 @@ class OptimizeShuffles : public IRMutator { Expr visit(const Let *op) override { lets.emplace_back(op->name, op->value); - Expr expr = visit_let(op); + Expr expr = visit_let(op); lets.pop_back(); return expr; } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } std::set allocations_to_pad; @@ -149,4 +149,4 @@ Stmt optimize_shuffles(Stmt s, int lut_alignment) { } } // namespace Internal -} // namespace Halide \ No newline at end of file +} // namespace Halide diff --git a/src/RemoveUndef.cpp b/src/RemoveUndef.cpp index 9103e401fb33..c738b23311aa 100644 --- a/src/RemoveUndef.cpp +++ b/src/RemoveUndef.cpp @@ -241,25 +241,25 @@ class RemoveUndef : public IRMutator { } } - template - Body visit_let(const T *op) { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { // Visit an entire chain of lets in a single method to conserve stack space. struct Frame { - const T *op; + const LetOrLetStmt *op; Expr new_value; ScopedBinding<> binding; - Frame(const T *op, Expr v, Scope<> &scope) + Frame(const LetOrLetStmt *op, Expr v, Scope<> &scope) : op(op), new_value(std::move(v)), binding(!new_value.defined(), scope, op->name) { } }; vector frames; - Body result; + decltype(op->body) result; do { frames.emplace_back(op, mutate(op->value), dead_vars); result = op->body; - } while ((op = result.template as())); + } while ((op = result.template as())); result = mutate(result); @@ -272,7 +272,7 @@ class RemoveUndef : public IRMutator { if (frame.new_value.same_as(frame.op->value) && result.same_as(frame.op->body)) { result = frame.op; } else { - result = T::make(frame.op->name, std::move(frame.new_value), result); + result = LetOrLetStmt::make(frame.op->name, std::move(frame.new_value), result); } } } @@ -281,11 +281,11 @@ class RemoveUndef : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const AssertStmt *op) override { diff --git a/src/SimplifyCorrelatedDifferences.cpp b/src/SimplifyCorrelatedDifferences.cpp index 7c6a3bac735c..ba9c2d772fff 100644 --- a/src/SimplifyCorrelatedDifferences.cpp +++ b/src/SimplifyCorrelatedDifferences.cpp @@ -78,8 +78,8 @@ class SimplifyCorrelatedDifferences : public IRMutator { }; vector lets; - template - StmtOrExpr visit_let(const LetStmtOrLet *op) { + template + auto visit_let(const LetStmtOrLet *op) -> decltype(op->body) { // Visit an entire chain of lets in a single method to conserve stack space. struct Frame { const LetStmtOrLet *op; @@ -94,7 +94,7 @@ class SimplifyCorrelatedDifferences : public IRMutator { } }; std::vector frames; - StmtOrExpr result; + decltype(op->body) result; // Note that we must add *everything* that depends on the loop // var to the monotonic scope and the list of lets, even @@ -146,11 +146,11 @@ class SimplifyCorrelatedDifferences : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const For *op) override { diff --git a/src/TrimNoOps.cpp b/src/TrimNoOps.cpp index 1d7232a89b11..e516a17b12bd 100644 --- a/src/TrimNoOps.cpp +++ b/src/TrimNoOps.cpp @@ -309,10 +309,10 @@ class SimplifyUsingBounds : public IRMutator { return visit_cmp(op); } - template - StmtOrExpr visit_let(const LetStmtOrLet *op) { + template + auto visit_let(const LetStmtOrLet *op) -> decltype(op->body) { Expr value = mutate(op->value); - StmtOrExpr body; + decltype(op->body) body; if (value.type() == Int(32) && is_pure(value)) { containing_loops.push_back({op->name, {value, value}}); body = mutate(op->body); @@ -324,11 +324,11 @@ class SimplifyUsingBounds : public IRMutator { } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Stmt visit(const For *op) override { diff --git a/src/UniquifyVariableNames.cpp b/src/UniquifyVariableNames.cpp index 2d89e77de787..9dc92c780b3a 100644 --- a/src/UniquifyVariableNames.cpp +++ b/src/UniquifyVariableNames.cpp @@ -130,15 +130,15 @@ class FindFreeVars : public IRVisitor { } } - template - void visit_let(const T *op) { + template + void visit_let(const LetOrLetStmt *op) { vector> frame; decltype(op->body) body; do { op->value.accept(this); frame.emplace_back(scope, op->name); body = op->body; - op = body.template as(); + op = body.template as(); } while (op); body.accept(this); } diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index 7b16083ca5e2..0c7e762ad262 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -1411,8 +1411,8 @@ class FindVectorizableExprsInAtomicNode : public IRMutator { using IRMutator::visit; - template - const T *visit_let(const T *op) { + template + const LetOrLetStmt *visit_let(const LetOrLetStmt *op) { mutate(op->value); ScopedBinding<> bind_if(poison, poisoned_names, op->name); mutate(op->body); @@ -1498,8 +1498,8 @@ class LiftVectorizableExprsOutOfSingleAtomicNode : public IRMutator { using IRMutator::visit; - template - StmtOrExpr visit_let(const LetStmtOrLet *op) { + template + auto visit_let(const LetStmtOrLet *op) -> decltype(op->body) { if (liftable.count(op->value)) { // Lift it under its current name to avoid having to // rewrite the variables in other lifted exprs. @@ -1512,11 +1512,11 @@ class LiftVectorizableExprsOutOfSingleAtomicNode : public IRMutator { } Stmt visit(const LetStmt *op) override { - return visit_let(op); + return visit_let(op); } Expr visit(const Let *op) override { - return visit_let(op); + return visit_let(op); } public: From 5783534220cff5f7773aafebeea337cd499ff56f Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 27 Dec 2024 10:25:07 -0800 Subject: [PATCH 14/16] Fix UB-introducing rewrite in FindIntrinsics (#8539) FindIntrinsics was rewriting i8(rounding_shift_left(i16(foo_i8), 11)) to rounding_shift_left(foo_i8, 11) I.e. it decided it could do the shift in the narrower type. However this isn't correct, because 11 is a valid shift for a 16-bit int, but not for an 8-bit int. The former is zero, the latter gets turned into a poison value because we lower it to llvm's shl. This was discovered by a random failure in test/correctness/lossless_cast.cpp in another PR for seed 826708018. This PR fixes this case by adding a compile-time check that the shift is in-range. For the examples in test/correctness/intrinsics.cpp the shift amount ends up in a let, so making this work on those cases required handling a TODO: tracking the constant integer bounds of variables in scope in FindIntrinsics and therefore also in lossless_cast. --- src/FindIntrinsics.cpp | 63 ++++++++++++++++++++++++++++++--- src/IROperator.cpp | 57 +++++++++++++++-------------- src/IROperator.h | 19 ++++++---- test/correctness/intrinsics.cpp | 18 +++++----- 4 files changed, 109 insertions(+), 48 deletions(-) diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 4034d7aea629..7d2e522a9937 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -101,11 +101,10 @@ class FindIntrinsics : public IRMutator { Scope let_var_bounds; Expr lossless_cast(Type t, const Expr &e) { - return Halide::Internal::lossless_cast(t, e, &bounds_cache); + return Halide::Internal::lossless_cast(t, e, let_var_bounds, &bounds_cache); } ConstantInterval constant_integer_bounds(const Expr &e) { - // TODO: Use the scope - add let visitors return Halide::Internal::constant_integer_bounds(e, let_var_bounds, &bounds_cache); } @@ -210,6 +209,51 @@ class FindIntrinsics : public IRMutator { return Expr(); } + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { + struct Frame { + const LetOrLetStmt *orig; + Expr new_value; + ScopedBinding bind; + Frame(const LetOrLetStmt *orig, + Expr &&new_value, + ScopedBinding &&bind) + : orig(orig), new_value(std::move(new_value)), bind(std::move(bind)) { + } + }; + std::vector frames; + decltype(op->body) body; + while (op) { + Expr v = mutate(op->value); + ConstantInterval b = constant_integer_bounds(v); + frames.emplace_back(op, + std::move(v), + ScopedBinding(let_var_bounds, op->name, b)); + body = op->body; + op = body.template as(); + } + + body = mutate(body); + + for (const auto &f : reverse_view(frames)) { + if (f.new_value.same_as(f.orig->value) && body.same_as(f.orig->body)) { + body = f.orig; + } else { + body = LetOrLetStmt::make(f.orig->name, f.new_value, body); + } + } + + return body; + } + + Expr visit(const Let *op) override { + return visit_let(op); + } + + Stmt visit(const LetStmt *op) override { + return visit_let(op); + } + Expr visit(const Add *op) override { if (!find_intrinsics_for_type(op->type)) { return IRMutator::visit(op); @@ -697,7 +741,12 @@ class FindIntrinsics : public IRMutator { bool is_saturated = op->value.as() || op->value.as(); Expr a = lossless_cast(op->type, shift->args[0]); Expr b = lossless_cast(op->type.with_code(shift->args[1].type().code()), shift->args[1]); - if (a.defined() && b.defined()) { + // Doing the shift in the narrower type might introduce UB where + // there was no UB before, so we need to make sure b is bounded. + auto b_bounds = constant_integer_bounds(b); + const int max_shift = op->type.bits() - 1; + + if (a.defined() && b.defined() && b_bounds >= -max_shift && b_bounds <= max_shift) { if (!is_saturated || (shift->is_intrinsic(Call::rounding_shift_right) && can_prove(b >= 0)) || (shift->is_intrinsic(Call::rounding_shift_left) && can_prove(b <= 0))) { @@ -1118,8 +1167,12 @@ class SubstituteInWideningLets : public IRMutator { std::string name; Expr new_value; ScopedBinding bind; - Frame(const std::string &name, const Expr &new_value, ScopedBinding &&bind) - : name(name), new_value(new_value), bind(std::move(bind)) { + Frame(const std::string &name, + const Expr &new_value, + ScopedBinding &&bind) + : name(name), + new_value(new_value), + bind(std::move(bind)) { } }; std::vector frames; diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 26869c4d8b15..e2143470e497 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -427,17 +427,20 @@ Expr const_false(int w) { return make_zero(UInt(1, w)); } -Expr lossless_cast(Type t, Expr e, std::map *cache) { +Expr lossless_cast(Type t, + Expr e, + const Scope &scope, + std::map *cache) { if (!e.defined() || t == e.type()) { return e; } else if (t.can_represent(e.type())) { return cast(t, std::move(e)); } else if (const Cast *c = e.as()) { if (c->type.can_represent(c->value.type())) { - return lossless_cast(t, c->value, cache); + return lossless_cast(t, c->value, scope, cache); } } else if (const Broadcast *b = e.as()) { - Expr v = lossless_cast(t.element_of(), b->value, cache); + Expr v = lossless_cast(t.element_of(), b->value, scope, cache); if (v.defined()) { return Broadcast::make(v, b->lanes); } @@ -456,7 +459,7 @@ Expr lossless_cast(Type t, Expr e, std::map } else if (const Shuffle *shuf = e.as()) { std::vector vecs; for (const auto &vec : shuf->vectors) { - vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, cache)); + vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, scope, cache)); if (!vecs.back().defined()) { return Expr(); } @@ -465,73 +468,73 @@ Expr lossless_cast(Type t, Expr e, std::map } else if (t.is_int_or_uint()) { // Check the bounds. If they're small enough, we can throw narrowing // casts around e, or subterms. - ConstantInterval ci = constant_integer_bounds(e, Scope::empty_scope(), cache); + ConstantInterval ci = constant_integer_bounds(e, scope, cache); if (t.can_represent(ci)) { // There are certain IR nodes where if the result is expressible // using some type, and the args are expressible using that type, // then the operation can just be done in that type. if (const Add *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { return Add::make(a, b); } } else if (const Sub *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { return Sub::make(a, b); } } else if (const Mul *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { return Mul::make(a, b); } } else if (const Min *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { debug(0) << a << " " << b << "\n"; return Min::make(a, b); } } else if (const Max *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { return Max::make(a, b); } } else if (const Mod *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { return Mod::make(a, b); } } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_add, Call::widen_right_add})) { - Expr a = lossless_cast(t, op->args[0], cache); - Expr b = lossless_cast(t, op->args[1], cache); + Expr a = lossless_cast(t, op->args[0], scope, cache); + Expr b = lossless_cast(t, op->args[1], scope, cache); if (a.defined() && b.defined()) { return Add::make(a, b); } } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_sub, Call::widen_right_sub})) { - Expr a = lossless_cast(t, op->args[0], cache); - Expr b = lossless_cast(t, op->args[1], cache); + Expr a = lossless_cast(t, op->args[0], scope, cache); + Expr b = lossless_cast(t, op->args[1], scope, cache); if (a.defined() && b.defined()) { return Sub::make(a, b); } } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_mul, Call::widen_right_mul})) { - Expr a = lossless_cast(t, op->args[0], cache); - Expr b = lossless_cast(t, op->args[1], cache); + Expr a = lossless_cast(t, op->args[0], scope, cache); + Expr b = lossless_cast(t, op->args[1], scope, cache); if (a.defined() && b.defined()) { return Mul::make(a, b); } } else if (const Call *op = Call::as_intrinsic(e, {Call::shift_left, Call::widening_shift_left, Call::shift_right, Call::widening_shift_right})) { - Expr a = lossless_cast(t, op->args[0], cache); - Expr b = lossless_cast(t, op->args[1], cache); + Expr a = lossless_cast(t, op->args[0], scope, cache); + Expr b = lossless_cast(t, op->args[1], scope, cache); if (a.defined() && b.defined()) { - ConstantInterval cb = constant_integer_bounds(b, Scope::empty_scope(), cache); + ConstantInterval cb = constant_integer_bounds(b, scope, cache); if (cb > -t.bits() && cb < t.bits()) { if (op->is_intrinsic({Call::shift_left, Call::widening_shift_left})) { return a << b; @@ -544,7 +547,7 @@ Expr lossless_cast(Type t, Expr e, std::map if (op->op == VectorReduce::Add || op->op == VectorReduce::Min || op->op == VectorReduce::Max) { - Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, cache); + Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, scope, cache); if (v.defined()) { return VectorReduce::make(op->op, v, op->type.lanes()); } diff --git a/src/IROperator.h b/src/IROperator.h index 2b0ce6d97563..02a69ed053e0 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -11,7 +11,9 @@ #include #include +#include "ConstantInterval.h" #include "Expr.h" +#include "Scope.h" #include "Target.h" #include "Tuple.h" @@ -150,13 +152,16 @@ Expr const_false(int lanes = 1); /** Attempt to cast an expression to a smaller type while provably not losing * information. If it can't be done, return an undefined Expr. * - * Optionally accepts a map that gives the constant bounds of exprs already - * analyzed to avoid redoing work across many calls to lossless_cast. It is not - * safe to use this optional map in contexts where the same Expr object may - * take on a different value. For example: - * (let x = 4 in some_expr_object) + (let x = 5 in the_same_expr_object)). - * It is safe to use it after uniquify_variable_names has been run. */ -Expr lossless_cast(Type t, Expr e, std::map *cache = nullptr); + * Optionally accepts a scope giving the constant bounds of any variables, and a + * map that gives the constant bounds of exprs already analyzed to avoid redoing + * work across many calls to lossless_cast. It is not safe to use this optional + * map in contexts where the same Expr object may take on a different value. For + * example: (let x = 4 in some_expr_object) + (let x = 5 in + * the_same_expr_object)). It is safe to use it after uniquify_variable_names + * has been run. */ +Expr lossless_cast(Type t, Expr e, + const Scope &scope = Scope::empty_scope(), + std::map *cache = nullptr); /** Attempt to negate x without introducing new IR and without overflow. * If it can't be done, return an undefined Expr. */ diff --git a/test/correctness/intrinsics.cpp b/test/correctness/intrinsics.cpp index e5119bd5e1be..757e1360be36 100644 --- a/test/correctness/intrinsics.cpp +++ b/test/correctness/intrinsics.cpp @@ -255,17 +255,17 @@ int main(int argc, char **argv) { check((u64(u32x) + 8) / 16, u64(rounding_shift_right(u32x, 4))); check(u16(min((u64(u32x) + 8) / 16, 65535)), u16_sat(rounding_shift_right(u32x, 4))); - // And with variable shifts. - check(i8(widening_add(i8x, (i8(1) << u8y) / 2) >> u8y), rounding_shift_right(i8x, u8y)); + // And with variable shifts. These won't match unless Halide can statically + // prove it's not an out-of-range shift. + Expr u8yc = Min::make(u8y, make_const(u8y.type(), 7)); + Expr i8yc = Max::make(Min::make(i8y, make_const(i8y.type(), 7)), make_const(i8y.type(), -7)); + check(i8(widening_add(i8x, (i8(1) << u8yc) / 2) >> u8yc), rounding_shift_right(i8x, u8yc)); check((i32x + (i32(1) << u32y) / 2) >> u32y, rounding_shift_right(i32x, u32y)); - - check(i8(widening_add(i8x, (i8(1) << max(i8y, 0)) / 2) >> i8y), rounding_shift_right(i8x, i8y)); + check(i8(widening_add(i8x, (i8(1) << max(i8yc, 0)) / 2) >> i8yc), rounding_shift_right(i8x, i8yc)); check((i32x + (i32(1) << max(i32y, 0)) / 2) >> i32y, rounding_shift_right(i32x, i32y)); - - check(i8(widening_add(i8x, (i8(1) >> min(i8y, 0)) / 2) << i8y), rounding_shift_left(i8x, i8y)); + check(i8(widening_add(i8x, (i8(1) >> min(i8yc, 0)) / 2) << i8yc), rounding_shift_left(i8x, i8yc)); check((i32x + (i32(1) >> min(i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y)); - - check(i8(widening_add(i8x, (i8(1) << -min(i8y, 0)) / 2) << i8y), rounding_shift_left(i8x, i8y)); + check(i8(widening_add(i8x, (i8(1) << -min(i8yc, 0)) / 2) << i8yc), rounding_shift_left(i8x, i8yc)); check((i32x + (i32(1) << -min(i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y)); check((i32x + (i32(1) << max(-i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y)); @@ -372,7 +372,7 @@ int main(int argc, char **argv) { f(x) = cast(x); f.compute_root(); - g(x) = rounding_shift_right(x, 0) + rounding_shift_left(x, 8); + g(x) = rounding_shift_right(f(x), 0) + u8(rounding_shift_left(u16(f(x)), 11)); g.compile_jit(); } From 50354c387ce286ad41665db078342518eb5b87f4 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 27 Dec 2024 10:25:34 -0800 Subject: [PATCH 15/16] Skip test when code could be using x87 (#8537) * Skip test when code could be using x87 * Add MSVC macro Co-authored-by: Alex Reinking --------- Co-authored-by: Alex Reinking --- test/correctness/saturating_casts.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/correctness/saturating_casts.cpp b/test/correctness/saturating_casts.cpp index 0ce1cfda7da7..7a17006b2e43 100644 --- a/test/correctness/saturating_casts.cpp +++ b/test/correctness/saturating_casts.cpp @@ -290,6 +290,13 @@ void test_one_source() { } int main(int argc, char **argv) { + +#if defined(__i386__) || defined(_M_IX86) + printf("[SKIP] Skipping test because it requires bit-exact int to float casts,\n" + "and on i386 without SSE it is hard to guarantee that the test binary won't use x87 instructions.\n"); + return 0; +#endif + test_one_source(); test_one_source(); test_one_source(); From a9f82db97a77a227c9aafbf8ce4c4d353a367e3f Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Fri, 27 Dec 2024 12:37:55 -0600 Subject: [PATCH 16/16] Rewrite the rfactor scheduling directive (#8490) The old implementation suffered from several serious issues. It duplicated substantial amounts of the logic in ApplySplit.cpp, the way it handled adapting the predicate to the reducing func was unprincipled, and it confused dims and vars in a way that could segfault. It also left the order of pure dimensions unspecified. The new implementation chooses to follow the existing dims list. We now disallow rfactor() on funcs with RVar+Var fused schedules. The implementation also relies on a new or_condition_over_domain helper function and drops the purify() scheduling directive. Fixes #7854 --- .../src/halide/halide_/PyStage.cpp | 2 +- src/ApplySplit.cpp | 9 +- src/BoundsInference.cpp | 24 +- src/Derivative.cpp | 2 +- src/Deserialization.cpp | 2 - src/Func.cpp | 874 ++++++++---------- src/Func.h | 7 +- src/Inline.cpp | 2 - src/Schedule.h | 7 +- src/ScheduleFunctions.cpp | 9 +- src/Serialization.cpp | 2 - src/Solve.cpp | 4 + src/Solve.h | 7 + src/Substitute.h | 11 + src/halide_ir.fbs | 1 - test/common/expect_abort.cpp | 3 + test/error/CMakeLists.txt | 2 + .../rfactor_after_var_and_rvar_fusion.cpp | 25 + test/error/rfactor_fused_var_and_rvar.cpp | 26 + 19 files changed, 463 insertions(+), 556 deletions(-) create mode 100644 test/error/rfactor_after_var_and_rvar_fusion.cpp create mode 100644 test/error/rfactor_fused_var_and_rvar.cpp diff --git a/python_bindings/src/halide/halide_/PyStage.cpp b/python_bindings/src/halide/halide_/PyStage.cpp index b412a6f2b39e..fac47fa3cf1f 100644 --- a/python_bindings/src/halide/halide_/PyStage.cpp +++ b/python_bindings/src/halide/halide_/PyStage.cpp @@ -14,7 +14,7 @@ void define_stage(py::module &m) { .def("dump_argument_list", &Stage::dump_argument_list) .def("name", &Stage::name) - .def("rfactor", (Func(Stage::*)(std::vector>)) & Stage::rfactor, + .def("rfactor", (Func(Stage::*)(const std::vector> &)) & Stage::rfactor, py::arg("preserved")) .def("rfactor", (Func(Stage::*)(const RVar &, const Var &)) & Stage::rfactor, py::arg("r"), py::arg("v")) diff --git a/src/ApplySplit.cpp b/src/ApplySplit.cpp index b6491f063fba..22c3425c02a4 100644 --- a/src/ApplySplit.cpp +++ b/src/ApplySplit.cpp @@ -157,7 +157,6 @@ vector apply_split(const Split &split, const string &prefix, } } break; case Split::RenameVar: - case Split::PurifyRVar: result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::Substitution); result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::LetStmt); break; @@ -167,10 +166,7 @@ vector apply_split(const Split &split, const string &prefix, } vector> compute_loop_bounds_after_split(const Split &split, const string &prefix) { - // Define the bounds on the split dimensions using the bounds - // on the function args. If it is a purify, we should use the bounds - // from the dims instead. - + // Define the bounds on the split dimensions using the bounds on the function args. vector> let_stmts; Expr old_var_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent"); @@ -201,9 +197,6 @@ vector> compute_loop_bounds_after_split(const Split &spl let_stmts.emplace_back(prefix + split.outer + ".loop_max", old_var_max); let_stmts.emplace_back(prefix + split.outer + ".loop_extent", old_var_extent); break; - case Split::PurifyRVar: - // Do nothing for purify - break; } return let_stmts; diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 724adb993afd..aba76f7798ed 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -14,6 +14,7 @@ #include #include +#include namespace Halide { namespace Internal { @@ -297,7 +298,6 @@ class BoundsInference : public IRMutator { } // Default case (no specialization) - vector predicates = def.split_predicate(); for (const ReductionVariable &rv : def.schedule().rvars()) { rvars.insert(rv); } @@ -308,23 +308,15 @@ class BoundsInference : public IRMutator { } vecs[1] = def.values(); + vector predicates = def.split_predicate(); for (size_t i = 0; i < result.size(); ++i) { for (const Expr &val : vecs[i]) { - if (!predicates.empty()) { - Expr cond_val = Call::make(val.type(), - Internal::Call::if_then_else, - {likely(predicates[0]), val}, - Internal::Call::PureIntrinsic); - for (size_t i = 1; i < predicates.size(); ++i) { - cond_val = Call::make(cond_val.type(), - Internal::Call::if_then_else, - {likely(predicates[i]), cond_val}, - Internal::Call::PureIntrinsic); - } - result[i].emplace_back(const_true(), cond_val); - } else { - result[i].emplace_back(const_true(), val); - } + Expr cond_val = std::accumulate( + predicates.begin(), predicates.end(), val, + [](const auto &acc, const auto &pred) { + return Call::make(acc.type(), Call::if_then_else, {likely(pred), acc}, Call::PureIntrinsic); + }); + result[i].emplace_back(const_true(), cond_val); } } diff --git a/src/Derivative.cpp b/src/Derivative.cpp index e64fb4ada94b..05d3168c95d3 100644 --- a/src/Derivative.cpp +++ b/src/Derivative.cpp @@ -1532,7 +1532,7 @@ void ReverseAccumulationVisitor::propagate_halide_function_call( // f(r.x) = ... && r is associative // => f(x) = ... if (var != nullptr && var->reduction_domain.defined() && - var->reduction_domain.split_predicate().empty()) { + is_const_one(var->reduction_domain.predicate())) { ReductionDomain rdom = var->reduction_domain; int rvar_id = -1; for (int rid = 0; rid < (int)rdom.domain().size(); rid++) { diff --git a/src/Deserialization.cpp b/src/Deserialization.cpp index b915a507f090..b6f49cb1bf43 100644 --- a/src/Deserialization.cpp +++ b/src/Deserialization.cpp @@ -368,8 +368,6 @@ Split::SplitType Deserializer::deserialize_split_type(Serialize::SplitType split return Split::SplitType::RenameVar; case Serialize::SplitType::FuseVars: return Split::SplitType::FuseVars; - case Serialize::SplitType::PurifyRVar: - return Split::SplitType::PurifyRVar; default: user_error << "unknown split type " << (int)split_type << "\n"; return Split::SplitType::SplitVar; diff --git a/src/Func.cpp b/src/Func.cpp index 8ffa9cb1e563..e87b14c89c9a 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -1,6 +1,8 @@ #include #include #include +#include +#include #include #ifdef _MSC_VER @@ -35,8 +37,12 @@ namespace Halide { using std::map; using std::ofstream; +using std::optional; using std::pair; using std::string; +using std::tuple; +using std::unordered_map; +using std::unordered_set; using std::vector; using namespace Internal; @@ -425,7 +431,6 @@ void check_for_race_conditions_in_split_with_blend(const StageSchedule &sched) { } break; case Split::RenameVar: - case Split::PurifyRVar: if (parallel.count(split.outer)) { parallel.insert(split.old_var); } @@ -448,7 +453,6 @@ void check_for_race_conditions_in_split_with_blend(const StageSchedule &sched) { } break; case Split::RenameVar: - case Split::PurifyRVar: if (parallel.count(split.old_var)) { parallel.insert(split.outer); } @@ -602,530 +606,427 @@ class SubstituteSelfReference : public IRMutator { /** Substitute all self-reference calls to 'func' with 'substitute' which * args (LHS) is the old args (LHS) plus 'new_args' in that order. * Expect this method to be called on the value (RHS) of an update definition. */ -Expr substitute_self_reference(Expr val, const string &func, const Function &substitute, - const vector &new_args) { +vector substitute_self_reference(const vector &values, const string &func, + const Function &substitute, const vector &new_args) { SubstituteSelfReference subs(func, substitute, new_args); - val = subs.mutate(val); - return val; -} - -// Substitute the occurrence of 'name' in 'exprs' with 'value'. -void substitute_var_in_exprs(const string &name, const Expr &value, vector &exprs) { - for (auto &expr : exprs) { - expr = substitute(name, value, expr); + vector result; + for (const auto &val : values) { + result.push_back(subs.mutate(val)); } + return result; } -void apply_split_result(const vector> &bounds_let_stmts, - const vector &splits_result, - vector &predicates, vector &args, - vector &values) { - - for (const auto &res : splits_result) { - switch (res.type) { - case ApplySplitResult::Substitution: - case ApplySplitResult::LetStmt: - // Apply substitutions to the list of predicates, args, and values. - // Make sure we substitute in all the let stmts as well since we are - // not going to add them to the exprs. - substitute_var_in_exprs(res.name, res.value, predicates); - substitute_var_in_exprs(res.name, res.value, args); - substitute_var_in_exprs(res.name, res.value, values); - break; - default: - internal_assert(res.type == ApplySplitResult::Predicate); - predicates.push_back(res.value); - break; - } - } +} // anonymous namespace - // Make sure we substitute in all the let stmts from 'bounds_let_stmts' - // since we are not going to add them to the exprs. - for (const auto &let : bounds_let_stmts) { - substitute_var_in_exprs(let.first, let.second, predicates); - substitute_var_in_exprs(let.first, let.second, args); - substitute_var_in_exprs(let.first, let.second, values); - } +Func Stage::rfactor(const RVar &r, const Var &v) { + definition.schedule().touched() = true; + return rfactor({{r, v}}); } -/** Apply split directives on the reduction variables. Remove the old RVar from - * the list and add the split result (inner and outer RVars) to the list. Add - * new predicates corresponding to the TailStrategy to the RDom predicate list. */ -bool apply_split(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::SplitVar); - const auto it = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); - - Expr old_max, old_min, old_extent; - - if (it != rvars.end()) { - debug(4) << " Splitting " << it->var << " into " << s.outer << " and " << s.inner << "\n"; - - old_max = simplify(it->min + it->extent - 1); - old_min = it->min; - old_extent = it->extent; - - it->var = s.inner; - it->min = 0; - it->extent = s.factor; - - rvars.insert(it + 1, {s.outer, 0, simplify((old_extent - 1 + s.factor) / s.factor)}); - - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); +// Helpers for rfactor implementation +namespace { - return true; +optional find_dim(const vector &items, const VarOrRVar &v) { + const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { + return dim_match(x, v); + }); + return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); +} + +using SubstitutionMap = std::map; + +/** This is a helper function for building up a substitution map that + * corresponds to pushing down a nest of lets. The lets should be fed + * to this function from innermost to outermost. This is equivalent to + * building a let-nest as a one-hole context and then simplifying. + * + * This looks like it might be quadratic or worse, and technically it is, + * but this isn't a problem for the way it is used inside rfactor. There + * are only a few uses: + * + * 1. Remapping preserved RVars to new RVars + * 2. Remapping factored RVars to new Vars + * 3. Filling the holes in the associative template + * 4. Accumulating the lets from ApplySplit + * + * These are naturally bounded by O(#splits + #dims) which is quite small + * in practice. Classes (1) and (2) cannot blow up expressions since they + * simply rename variables. Class (3) cannot blow up expressions either + * since nothing else can refer to the holes. That leaves only class (1). + * Fortunately, the lets generated by splits are benign. Split factors can't + * refer to RVars, and we won't see the consumed RVars in another split. So + * in total, this avoids any sort of exponentially sized substitution. + * + * @param subst The existing let nest (represented by a SubstitutionMap). + * @param name The name to bind, cannot already exist in the nest. + * @param value The value to bind. Will be substituted into nested values. + */ +void add_let(SubstitutionMap &subst, const string &name, const Expr &value) { + internal_assert(!subst.count(name)) << "would shadow " << name << " in let nest.\n" + << "\tPresent value: " << subst[name] << "\n" + << "\tProposed value: " << value; + for (auto &[_, e] : subst) { + e = substitute(name, value, e); + } + subst.emplace(name, value); +} + +pair project_rdom(const vector &dims, const ReductionDomain &rdom, const vector &splits) { + // The bounds projections maps expressions that reference the old RDom + // bounds to expressions that reference the new RDom bounds (from dims). + // We call this a projection because we are computing the symbolic image + // of the N-dimensional RDom (dimensionality including splits) in the + // M < N - dimensional result. + SubstitutionMap bounds_projection{}; + for (const Split &split : reverse_view(splits)) { + for (const auto &[name, value] : compute_loop_bounds_after_split(split, "")) { + add_let(bounds_projection, name, value); + } + } + for (const auto &[var, min, extent] : rdom.domain()) { + add_let(bounds_projection, var + ".loop_min", min); + add_let(bounds_projection, var + ".loop_max", min + extent - 1); + add_let(bounds_projection, var + ".loop_extent", extent); + } + + // Build the new RDom from the bounds_projection. + vector new_rvars; + for (const Dim &dim : dims) { + const Expr new_min = simplify(bounds_projection.at(dim.var + ".loop_min")); + const Expr new_extent = simplify(bounds_projection.at(dim.var + ".loop_extent")); + new_rvars.push_back(ReductionVariable{dim.var, new_min, new_extent}); + } + ReductionDomain new_rdom{new_rvars}; + new_rdom.where(rdom.predicate()); + + // Compute a mapping from old dimensions to equivalent values using only + // the new dimensions. For example, if we have an RDom {{0, 20}} and we + // split r.x by 2 into r.xo and r.xi, then this map will contain: + // r.x ~> 2 * r.xo + r.xi + // Certain split tail cases can place additional predicates on the RDom. + // These are handled here, too. + SubstitutionMap dim_projection{}; + SubstitutionMap dim_extent_alignment{}; + for (const auto &[var, _, extent] : rdom.domain()) { + dim_extent_alignment[var] = extent; + } + for (const Split &split : splits) { + for (const auto &result : apply_split(split, "", dim_extent_alignment)) { + switch (result.type) { + case ApplySplitResult::LetStmt: + add_let(dim_projection, result.name, substitute(bounds_projection, result.value)); + break; + case ApplySplitResult::PredicateCalls: + case ApplySplitResult::PredicateProvides: + case ApplySplitResult::Predicate: + new_rdom.where(substitute(bounds_projection, result.value)); + break; + case ApplySplitResult::Substitution: + case ApplySplitResult::SubstitutionInCalls: + case ApplySplitResult::SubstitutionInProvides: + case ApplySplitResult::BlendProvides: + // The lets returned by ApplySplit are sufficient + break; + } + } } - return false; -} - -/** Apply fuse directives on the reduction variables. Remove the - * fused RVars from the list and add the fused RVar to the list. */ -bool apply_fuse(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::FuseVars); - const auto &iter_outer = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.outer == rv.var); }); - const auto &iter_inner = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.inner == rv.var); }); - - Expr inner_min, inner_extent, outer_min, outer_extent; - if ((iter_outer != rvars.end()) && (iter_inner != rvars.end())) { - debug(4) << " Fusing " << s.outer << " and " << s.inner << " into " << s.old_var << "\n"; - - inner_min = iter_inner->min; - inner_extent = iter_inner->extent; - outer_min = iter_outer->min; - outer_extent = iter_outer->extent; - - Expr extent = iter_outer->extent * iter_inner->extent; - iter_outer->var = s.old_var; - iter_outer->min = 0; - iter_outer->extent = extent; - rvars.erase(iter_inner); - - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); - - return true; + for (const auto &rv : new_rdom.domain()) { + add_let(dim_projection, rv.var, Variable::make(Int(32), rv.var, new_rdom)); } - return false; + return {new_rdom, dim_projection}; } -/** Apply purify directives on the reduction variables and predicates. Purify - * replace a RVar with a Var, thus, the RVar needs to be removed from the list. - * Any reference to the RVar in the predicates will be replaced with reference - * to a Var. */ -bool apply_purify(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::PurifyRVar); - const auto &iter = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); - if (iter != rvars.end()) { - debug(4) << " Purify RVar " << iter->var << " into Var " << s.outer - << ", deleting it from the rvars list\n"; - rvars.erase(iter); +} // namespace - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); +pair, vector> Stage::rfactor_validate_args(const std::vector> &preserved, const AssociativeOp &prover_result) { + const vector &dims = definition.schedule().dims(); - return true; - } - return false; -} + user_assert(prover_result.associative()) + << "In schedule for " << name() << ": can't perform rfactor() " + << "because we can't prove associativity of the operator\n" + << dump_argument_list(); -/** Apply rename directives on the reduction variables. */ -bool apply_rename(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::RenameVar); - const auto &iter = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); - if (iter != rvars.end()) { - debug(4) << " Renaming " << iter->var << " into " << s.outer << "\n"; - iter->var = s.outer; + unordered_set is_rfactored; + for (const auto &[rv, v] : preserved) { + // Check that the RVars are in the dims list + const auto &rv_dim = find_dim(dims, rv); + user_assert(rv_dim && rv_dim->is_rvar()) + << "In schedule for " << name() << ": can't perform rfactor() " + << "on " << rv.name() << " since either it is not in the reduction " + << "domain, or has already been consumed by another scheduling directive\n" + << dump_argument_list(); - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); + is_rfactored.insert(rv_dim->var); - return true; + // Check that the new pure Vars we used to rename the RVar aren't already in the dims list + user_assert(!find_dim(dims, v)) + << "In schedule for " << name() << ": can't perform rfactor() " + << "on " << rv.name() << " because the name " << v.name() + << "is already used elsewhere in the Func's schedule.\n" + << dump_argument_list(); } - return false; -} -/** Apply scheduling directives (e.g. split, fuse, etc.) on the reduction - * variables. */ -bool apply_split_directive(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values) { - map dim_extent_alignment; - for (const ReductionVariable &rv : rvars) { - dim_extent_alignment[rv.var] = rv.extent; - } + // If the operator is associative but non-commutative, rfactor() on inner + // dimensions (excluding the outer dimensions) is not valid. + if (!prover_result.commutative()) { + optional last_rvar; + for (const auto &d : reverse_view(dims)) { + bool is_inner = is_rfactored.count(d.var) && last_rvar && !is_rfactored.count(last_rvar->var); + user_assert(!is_inner) + << "In schedule for " << name() << ": can't rfactor an inner " + << "dimension " << d.var << " without rfactoring the outer " + << "dimensions, since the operator is non-commutative.\n" + << dump_argument_list(); - vector> rvar_bounds; - for (const ReductionVariable &rv : rvars) { - rvar_bounds.emplace_back(rv.var + ".loop_min", rv.min); - rvar_bounds.emplace_back(rv.var + ".loop_max", simplify(rv.min + rv.extent - 1)); - rvar_bounds.emplace_back(rv.var + ".loop_extent", rv.extent); + if (d.is_rvar()) { + last_rvar = d; + } + } } - bool found = false; - switch (s.split_type) { - case Split::SplitVar: - found = apply_split(s, rvars, predicates, args, values, dim_extent_alignment); - break; - case Split::FuseVars: - found = apply_fuse(s, rvars, predicates, args, values, dim_extent_alignment); - break; - case Split::PurifyRVar: - found = apply_purify(s, rvars, predicates, args, values, dim_extent_alignment); - break; - case Split::RenameVar: - found = apply_rename(s, rvars, predicates, args, values, dim_extent_alignment); - break; + // Check that no Vars were fused into RVars + vector var_splits, rvar_splits; + Scope<> rdims; + for (const ReductionVariable &rv : definition.schedule().rvars()) { + rdims.push(rv.var); } + for (const Split &split : definition.schedule().splits()) { + switch (split.split_type) { + case Split::SplitVar: + if (rdims.contains(split.old_var)) { + rdims.pop(split.old_var); + rdims.push(split.outer); + rdims.push(split.inner); + rvar_splits.emplace_back(split); + } else { + var_splits.emplace_back(split); + } + break; + case Split::FuseVars: + if (rdims.contains(split.outer) || rdims.contains(split.inner)) { + user_assert(rdims.contains(split.outer) && rdims.contains(split.inner)) + << "In schedule for " << name() << ": can't rfactor an Func " + << "that has fused a Var into an RVar: " << split.outer + << ", " << split.inner << "\n" + << dump_argument_list(); - if (found) { - for (const auto &let : rvar_bounds) { - substitute_var_in_exprs(let.first, let.second, predicates); - substitute_var_in_exprs(let.first, let.second, args); - substitute_var_in_exprs(let.first, let.second, values); + rdims.pop(split.outer); + rdims.pop(split.inner); + rdims.push(split.old_var); + rvar_splits.emplace_back(split); + } else { + var_splits.emplace_back(split); + } + break; + case Split::RenameVar: + if (rdims.contains(split.old_var)) { + rdims.pop(split.old_var); + rdims.push(split.outer); + rvar_splits.emplace_back(split); + } else { + var_splits.emplace_back(split); + } + break; } } - return found; + return std::make_pair(std::move(var_splits), std::move(rvar_splits)); } -} // anonymous namespace - -Func Stage::rfactor(const RVar &r, const Var &v) { - definition.schedule().touched() = true; - return rfactor({{r, v}}); -} - -Func Stage::rfactor(vector> preserved) { +Func Stage::rfactor(const vector> &preserved) { user_assert(!definition.is_init()) << "rfactor() must be called on an update definition\n"; definition.schedule().touched() = true; - const string &func_name = function.name(); - vector &args = definition.args(); - vector &values = definition.values(); - - // Figure out which pure vars were used in this update definition. - std::set pure_vars_used; - internal_assert(args.size() == dim_vars.size()); - for (size_t i = 0; i < args.size(); i++) { - if (const Internal::Variable *var = args[i].as()) { - if (var->name == dim_vars[i].name()) { - pure_vars_used.insert(var->name); - } - } - } - // Check whether the operator is associative and determine the operator and // its identity for each value in the definition if it is a Tuple - const auto &prover_result = prove_associativity(func_name, args, values); - - user_assert(prover_result.associative()) - << "Failed to call rfactor() on " << name() - << " since it can't prove associativity of the operator\n"; - internal_assert(prover_result.size() == values.size()); + const auto &prover_result = prove_associativity(function.name(), definition.args(), definition.values()); + + const auto &[var_splits, rvar_splits] = rfactor_validate_args(preserved, prover_result); + + const vector dim_vars_exprs = [&] { + vector result; + result.insert(result.end(), dim_vars.begin(), dim_vars.end()); + return result; + }(); + + // sort preserved by the dimension ordering + vector preserved_rvars; + vector preserved_vars; + vector preserved_rdims; + unordered_set preserved_rdims_set; + vector intermediate_rdims; + { + unordered_map dim_ordering; + for (size_t i = 0; i < definition.schedule().dims().size(); i++) { + dim_ordering.emplace(definition.schedule().dims()[i].var, i); + } - vector &splits = definition.schedule().splits(); - vector &dims = definition.schedule().dims(); - vector &rvars = definition.schedule().rvars(); - vector predicates = definition.split_predicate(); + vector> preserved_with_dims; + for (const auto &[rv, v] : preserved) { + const optional rdim = find_dim(definition.schedule().dims(), rv); + internal_assert(rdim); + preserved_with_dims.emplace_back(rv, v, *rdim); + } - Scope scope; // Contains list of RVars lifted to the intermediate Func - vector rvars_removed; + std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const auto &lhs, const auto &rhs) { + return dim_ordering.at(std::get<2>(lhs).var) < dim_ordering.at(std::get<2>(rhs).var); + }); - vector is_rfactored(dims.size(), false); - for (const pair &i : preserved) { - const RVar &rv = i.first; - const Var &v = i.second; - { - // Check that the RVar are in the dims list - const auto &iter = std::find_if(dims.begin(), dims.end(), - [&rv](const Dim &dim) { return var_name_match(dim.var, rv.name()); }); - user_assert((iter != dims.end()) && (*iter).is_rvar()) - << "In schedule for " << name() - << ", can't perform rfactor() on " << rv.name() - << " since it is not in the reduction domain\n" - << dump_argument_list(); - is_rfactored[iter - dims.begin()] = true; + for (const auto &[rv, v, dim] : preserved_with_dims) { + preserved_rvars.push_back(rv); + preserved_vars.push_back(v); + preserved_rdims.push_back(dim); + preserved_rdims_set.insert(dim.var); } - { - // Check that the new pure Vars we used to rename the RVar aren't already in the dims list - const auto &iter = std::find_if(dims.begin(), dims.end(), - [&v](const Dim &dim) { return var_name_match(dim.var, v.name()); }); - user_assert(iter == dims.end()) - << "In schedule for " << name() - << ", can't rename the rvars " << rv.name() << " into " << v.name() - << ", since it is already used in this Func's schedule elsewhere.\n" - << dump_argument_list(); - } - } - // If the operator is associative but non-commutative, rfactor() on inner - // dimensions (excluding the outer dimensions) is not valid. - if (!prover_result.commutative()) { - int last_rvar = -1; - for (int i = dims.size() - 1; i >= 0; --i) { - if ((last_rvar != -1) && is_rfactored[i]) { - user_assert(is_rfactored[last_rvar]) - << "In schedule for " << name() - << ", can't rfactor an inner dimension " << dims[i].var - << " without rfactoring the outer dimensions, since the " - << "operator is non-commutative.\n" - << dump_argument_list(); - } - if (dims[i].is_rvar()) { - last_rvar = i; + for (const Dim &dim : definition.schedule().dims()) { + if (dim.is_rvar() && !preserved_rdims_set.count(dim.var)) { + intermediate_rdims.push_back(dim); } } } - // We need to apply the split directives on the reduction vars, so that we can - // correctly lift the RVars not in 'rvars_kept' and distribute the RVars to the - // intermediate and merge Funcs. + // Project the RDom into each side + ReductionDomain intermediate_rdom, preserved_rdom; + SubstitutionMap intermediate_map, preserved_map; { - vector temp; - for (const Split &s : splits) { - // If it's already applied, we should remove it from the split list. - if (!apply_split_directive(s, rvars, predicates, args, values)) { - temp.push_back(s); - } - } - splits = temp; - } - - // Reduction domain of the intermediate update definition - vector intm_rvars; - for (const auto &rv : rvars) { - const auto &iter = std::find_if(preserved.begin(), preserved.end(), - [&rv](const pair &pair) { return var_name_match(rv.var, pair.first.name()); }); - if (iter == preserved.end()) { - intm_rvars.push_back(rv); - scope.push(rv.var, rv.var); - } - } - RDom intm_rdom(intm_rvars); - - // Sort the Rvars kept and their Vars replacement based on the RVars of - // the reduction domain AFTER applying the split directives, so that we - // can have a consistent args order for the update definition of the - // intermediate and new merge Funcs. - std::sort(preserved.begin(), preserved.end(), - [&](const pair &lhs, const pair &rhs) { - const auto &iter_lhs = std::find_if(rvars.begin(), rvars.end(), - [&lhs](const ReductionVariable &rv) { return var_name_match(rv.var, lhs.first.name()); }); - const auto &iter_rhs = std::find_if(rvars.begin(), rvars.end(), - [&rhs](const ReductionVariable &rv) { return var_name_match(rv.var, rhs.first.name()); }); - return iter_lhs < iter_rhs; - }); - // The list of RVars to keep in the new update definition - vector rvars_kept(preserved.size()); - // List of pure Vars to replace the RVars in the intermediate's update definition - vector vars_rename(preserved.size()); - for (size_t i = 0; i < preserved.size(); ++i) { - const auto &val = preserved[i]; - rvars_kept[i] = val.first; - vars_rename[i] = val.second; - } - - // List of RVars for the new reduction domain. Any RVars not in 'rvars_kept' - // are removed from the RDom - { - vector temp; - for (const auto &rv : rvars) { - const auto &iter = std::find_if(rvars_kept.begin(), rvars_kept.end(), - [&rv](const RVar &rvar) { return var_name_match(rv.var, rvar.name()); }); - if (iter != rvars_kept.end()) { - temp.push_back(rv); - } else { - rvars_removed.push_back(rv.var); - } - } - rvars.swap(temp); - } - RDom f_rdom(rvars); - - // Init definition of the intermediate Func + ReductionDomain rdom{definition.schedule().rvars(), definition.predicate(), true}; - // Compute args of the init definition of the intermediate Func. - // Replace the RVars, which are in 'rvars_kept', with the specified new pure - // Vars. Also, add the pure Vars of the original init definition as part of - // the args. - // For example, if we have the following Func f: - // f(x, y) = 10 - // f(r.x, r.y) += h(r.x, r.y) - // Calling f.update(0).rfactor({{r.y, u}}) will generate the following - // intermediate Func: - // f_intm(x, y, u) = 0 - // f_intm(r.x, u, u) += h(r.x, u) - - vector init_args; - init_args.insert(init_args.end(), dim_vars.begin(), dim_vars.end()); - init_args.insert(init_args.end(), vars_rename.begin(), vars_rename.end()); + // Intermediate + std::tie(intermediate_rdom, intermediate_map) = project_rdom(intermediate_rdims, rdom, rvar_splits); + for (size_t i = 0; i < preserved.size(); i++) { + add_let(intermediate_map, preserved_rdims[i].var, preserved_vars[i]); + } + intermediate_rdom.set_predicate(simplify(substitute(intermediate_map, intermediate_rdom.predicate()))); - vector init_vals(values.size()); - for (size_t i = 0; i < init_vals.size(); ++i) { - init_vals[i] = prover_result.pattern.identities[i]; + // Preserved + std::tie(preserved_rdom, preserved_map) = project_rdom(preserved_rdims, rdom, rvar_splits); + Scope intm_rdom; + for (const auto &[var, min, extent] : intermediate_rdom.domain()) { + intm_rdom.push(var, Interval{min, min + extent - 1}); + } + preserved_rdom.set_predicate(or_condition_over_domain(substitute(preserved_map, preserved_rdom.predicate()), intm_rdom)); } - Func intm(func_name + "_intm"); - intm(init_args) = Tuple(init_vals); + // Intermediate func + Func intm(function.name() + "_intm"); - // Args of the update definition of the intermediate Func - vector update_args(args.size() + vars_rename.size()); - - // We need to substitute the reference to the old RDom's RVars with - // the new RDom's RVars. Also, substitute the reference to RVars which - // are in 'rvars_kept' with their corresponding new pure Vars - map substitution_map; - for (size_t i = 0; i < intm_rvars.size(); ++i) { - substitution_map[intm_rvars[i].var] = intm_rdom[i]; - } - for (size_t i = 0; i < vars_rename.size(); i++) { - update_args[i + args.size()] = vars_rename[i]; - RVar rvar_kept = rvars_kept[i]; - // Find the full name of rvar_kept in rvars - const auto &iter = std::find_if(rvars.begin(), rvars.end(), - [&rvar_kept](const ReductionVariable &rv) { return var_name_match(rv.var, rvar_kept.name()); }); - substitution_map[iter->var] = vars_rename[i]; - } - for (size_t i = 0; i < args.size(); i++) { - Expr arg = substitute(substitution_map, args[i]); - update_args[i] = arg; + // Intermediate pure definition + { + vector args = dim_vars_exprs; + args.insert(args.end(), preserved_vars.begin(), preserved_vars.end()); + intm(args) = Tuple(prover_result.pattern.identities); } - // Compute the predicates for the intermediate Func and the new update definition - for (const Expr &pred : predicates) { - Expr subs_pred = substitute(substitution_map, pred); - intm_rdom.where(subs_pred); - if (!expr_uses_vars(pred, scope)) { - // Only keep the predicate that does not depend on the lifted RVars - // (either explicitly or implicitly). For example, if 'rx' is split - // into 'rxo' and 'rxi' and 'rxo' is part of the lifted RVars, we'll - // ignore every predicate that depends on 'rx' - f_rdom.where(pred); + // Intermediate update definition + { + vector args = definition.args(); + args.insert(args.end(), preserved_vars.begin(), preserved_vars.end()); + args = substitute(intermediate_map, args); + + vector values = definition.values(); + values = substitute_self_reference(values, function.name(), intm.function(), preserved_vars); + values = substitute(intermediate_map, values); + intm.function().define_update(args, values, intermediate_rdom); + + // Intermediate schedule + vector intm_dims = definition.schedule().dims(); + + // Replace rvar dims IN the preserved list with their Vars in the INTERMEDIATE Func + for (auto &dim : intm_dims) { + const auto it = std::find_if(preserved_rvars.begin(), preserved_rvars.end(), [&](const auto &rv) { + return dim_match(dim, rv); + }); + if (it != preserved_rvars.end()) { + const auto offset = it - preserved_rvars.begin(); + const auto &var = preserved_vars[offset]; + const auto &pure_dim = find_dim(intm.function().definition().schedule().dims(), var); + internal_assert(pure_dim); + dim = *pure_dim; + } } - } - definition.predicate() = f_rdom.domain().predicate(); - // The update values the intermediate Func should compute - vector update_vals(values.size()); - for (size_t i = 0; i < update_vals.size(); i++) { - Expr val = substitute(substitution_map, values[i]); - // Need to update the self-reference in the update definition to point - // to the new intermediate Func - val = substitute_self_reference(val, func_name, intm.function(), vars_rename); - update_vals[i] = val; - } - // There may not actually be a reference to the RDom in the args or values, - // so we use Function::define_update, which lets pass pass an explicit RDom. - intm.function().define_update(update_args, update_vals, intm_rdom.domain()); - - // Determine the dims and schedule of the update definition of the - // intermediate Func. We copy over the schedule from the original - // update definition (e.g. split, parallelize, vectorize, etc.) - intm.function().update(0).schedule().dims() = dims; - intm.function().update(0).schedule().splits() = splits; - - // Copy over the storage order of the original pure dims - vector &intm_storage_dims = intm.function().schedule().storage_dims(); - internal_assert(intm_storage_dims.size() == - function.schedule().storage_dims().size() + vars_rename.size()); - for (size_t i = 0; i < function.schedule().storage_dims().size(); ++i) { - intm_storage_dims[i] = function.schedule().storage_dims()[i]; - } - - for (size_t i = 0; i < rvars_kept.size(); ++i) { - // Apply the purify directive that replaces the RVar in rvars_kept - // with a pure Var - intm.update(0).purify(rvars_kept[i], vars_rename[i]); - } + // Add factored pure dims to the INTERMEDIATE func just before outermost + unordered_set dims; + for (const auto &dim : intm_dims) { + dims.insert(dim.var); + } + for (const Var &var : preserved_vars) { + const optional &dim = find_dim(intm.function().definition().schedule().dims(), var); + internal_assert(dim) << "Failed to find " << var.name() << " in list of pure dims"; + if (!dims.count(dim->var)) { + intm_dims.insert(intm_dims.end() - 1, *dim); + } + } - // Determine the dims of the new update definition - - // The new update definition needs all the pure vars of the Func, but the - // one we're rfactoring may not have used them all. Add any missing ones to - // the dims list. - - // Add pure Vars from the original init definition to the dims list - // if they are not already in the list - for (const Var &v : dim_vars) { - if (!pure_vars_used.count(v.name())) { - Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto}; - // Insert it just before Var::outermost - dims.insert(dims.end() - 1, d); - } + intm.function().update(0).schedule() = definition.schedule().get_copy(); + intm.function().update(0).schedule().dims() = std::move(intm_dims); + intm.function().update(0).schedule().rvars() = intermediate_rdom.domain(); + intm.function().update(0).schedule().splits() = var_splits; } - // Then, we need to remove lifted RVars from the dims list - for (const string &rv : rvars_removed) { - remove(rv); - } + // Preserved update definition + { + // Replace the current definition with calls to the intermediate func. + vector f_load_args = dim_vars_exprs; + for (const ReductionVariable &rv : preserved_rdom.domain()) { + f_load_args.push_back(Variable::make(Int(32), rv.var, preserved_rdom)); + } - // Define the new update definition which refers to the intermediate Func. - // Using the same example as above, the new update definition is: - // f(x, y) += f_intm(x, y, r.y) + for (size_t i = 0; i < definition.values().size(); ++i) { + if (!prover_result.ys[i].var.empty()) { + Expr r = (definition.values().size() == 1) ? Expr(intm(f_load_args)) : Expr(intm(f_load_args)[i]); + add_let(preserved_map, prover_result.ys[i].var, r); + } - // Args for store in the new update definition - vector f_store_args(dim_vars.size()); - for (size_t i = 0; i < f_store_args.size(); ++i) { - f_store_args[i] = dim_vars[i]; - } - - // Call's args to the intermediate Func in the new update definition - vector f_load_args; - f_load_args.insert(f_load_args.end(), dim_vars.begin(), dim_vars.end()); - for (int i = 0; i < f_rdom.dimensions(); ++i) { - f_load_args.push_back(f_rdom[i]); - } - internal_assert(f_load_args.size() == init_args.size()); + if (!prover_result.xs[i].var.empty()) { + Expr prev_val = Call::make(intm.types()[i], function.name(), + dim_vars_exprs, Call::CallType::Halide, + FunctionPtr(), i); + add_let(preserved_map, prover_result.xs[i].var, prev_val); + } else { + user_warning << "Update definition of " << name() << " at index " << i + << " doesn't depend on the previous value. This isn't a" + << " reduction operation\n"; + } + } - // Update value of the new update definition. It loads values from - // the intermediate Func. - vector f_values(values.size()); + vector reducing_dims; + { + // Remove rvar dims NOT IN the preserved list from the REDUCING Func + for (const auto &dim : definition.schedule().dims()) { + if (!dim.is_rvar() || preserved_rdims_set.count(dim.var)) { + reducing_dims.push_back(dim); + } + } - // There might be cross-dependencies between tuple elements, so we need - // to collect all substitutions first. - map replacements; - for (size_t i = 0; i < f_values.size(); ++i) { - if (!prover_result.ys[i].var.empty()) { - Expr r = (values.size() == 1) ? Expr(intm(f_load_args)) : Expr(intm(f_load_args)[i]); - replacements.emplace(prover_result.ys[i].var, r); + // Add missing pure vars to the REDUCING func just before outermost. + // This is necessary whenever the update does not reference one of the + // pure variables. For instance, factoring a histogram (clamps elided): + // g(x) = 0; g(f(r.x, r.y)) += 1; + // Func intm = g.rfactor(r.y, u); + // Here we generate an intermediate func intm that looks like: + // intm(x, u) = 0; intm(f(r.x, u), u) += 1; + // And we need the reducing func to be: + // g(x) += intm(x, r.y); + // But x was not referenced in the original update definition, so that + // dimension is added here. + for (size_t i = 0; i < dim_vars.size(); i++) { + if (!expr_uses_var(definition.args()[i], dim_vars[i].name())) { + Dim d = {dim_vars[i].name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto}; + reducing_dims.insert(reducing_dims.end() - 1, d); + } + } } - if (!prover_result.xs[i].var.empty()) { - Expr prev_val = Call::make(intm.types()[i], func_name, - f_store_args, Call::CallType::Halide, - FunctionPtr(), i); - replacements.emplace(prover_result.xs[i].var, prev_val); - } else { - user_warning << "Update definition of " << name() << " at index " << i - << " doesn't depend on the previous value. This isn't a" - << " reduction operation\n"; - } - } - for (size_t i = 0; i < f_values.size(); ++i) { - f_values[i] = substitute(replacements, prover_result.pattern.ops[i]); + definition.args() = dim_vars_exprs; + definition.values() = substitute(preserved_map, prover_result.pattern.ops); + definition.predicate() = preserved_rdom.predicate(); + definition.schedule().dims() = std::move(reducing_dims); + definition.schedule().rvars() = preserved_rdom.domain(); + definition.schedule().splits() = var_splits; } - // Update the definition - args.swap(f_store_args); - values.swap(f_values); - return intm; } @@ -1187,7 +1088,7 @@ void Stage::split(const string &old, const string &outer, const string &inner, c bool round_up_ok = !exact; if (round_up_ok && !definition.is_init()) { // If it's the outermost split in this dimension, RoundUp - // is OK. Otherwise we need GuardWithIf to avoid + // is OK. Otherwise, we need GuardWithIf to avoid // recomputing values in the case where the inner split // factor does not divide the outer split factor. std::set inner_vars; @@ -1200,7 +1101,6 @@ void Stage::split(const string &old, const string &outer, const string &inner, c } break; case Split::RenameVar: - case Split::PurifyRVar: if (inner_vars.count(s.old_var)) { inner_vars.insert(s.outer); } @@ -1224,7 +1124,7 @@ void Stage::split(const string &old, const string &outer, const string &inner, c bool predicate_loads_ok = !exact; if (predicate_loads_ok && tail == TailStrategy::PredicateLoads) { // If it's the outermost split in this dimension, PredicateLoads - // is OK. Otherwise we can't prove it's safe. + // is OK. Otherwise, we can't prove it's safe. std::set inner_vars; for (const Split &s : definition.schedule().splits()) { switch (s.split_type) { @@ -1235,7 +1135,6 @@ void Stage::split(const string &old, const string &outer, const string &inner, c } break; case Split::RenameVar: - case Split::PurifyRVar: if (inner_vars.count(s.old_var)) { inner_vars.insert(s.outer); } @@ -1297,7 +1196,6 @@ void Stage::split(const string &old, const string &outer, const string &inner, c } break; case Split::RenameVar: - case Split::PurifyRVar: if (it != descends_from_shiftinwards_outer.end()) { descends_from_shiftinwards_outer[s.outer] = it->second; } @@ -1484,46 +1382,6 @@ void Stage::specialize_fail(const std::string &message) { s.failure_message = message; } -Stage &Stage::purify(const VarOrRVar &old_var, const VarOrRVar &new_var) { - user_assert(old_var.is_rvar && !new_var.is_rvar) - << "In schedule for " << name() - << ", can't rename " << (old_var.is_rvar ? "RVar " : "Var ") << old_var.name() - << " to " << (new_var.is_rvar ? "RVar " : "Var ") << new_var.name() - << "; purify must take a RVar as old_Var and a Var as new_var\n"; - - debug(4) << "In schedule for " << name() << ", purify RVar " - << old_var.name() << " to Var " << new_var.name() << "\n"; - - StageSchedule &schedule = definition.schedule(); - - // Replace the old dimension with the new dimensions in the dims list - bool found = false; - string old_name, new_name = new_var.name(); - vector &dims = schedule.dims(); - - for (size_t i = 0; (!found) && i < dims.size(); i++) { - if (dim_match(dims[i], old_var)) { - found = true; - old_name = dims[i].var; - dims[i].var = new_name; - dims[i].dim_type = DimType::PureVar; - } - } - - if (!found) { - user_error - << "In schedule for " << name() - << ", could not find rename dimension: " - << old_var.name() - << "\n" - << dump_argument_list(); - } - - Split split = {old_name, new_name, "", 1, false, TailStrategy::RoundUp, Split::PurifyRVar}; - definition.schedule().splits().push_back(split); - return *this; -} - void Stage::remove(const string &var) { debug(4) << "In schedule for " << name() << ", remove " << var << "\n"; @@ -1601,7 +1459,6 @@ void Stage::remove(const string &var) { } break; case Split::RenameVar: - case Split::PurifyRVar: debug(4) << " replace/rename " << split.old_var << " into " << split.outer << "\n"; if (should_remove(split.outer)) { @@ -1690,7 +1547,6 @@ Stage &Stage::rename(const VarOrRVar &old_var, const VarOrRVar &new_var) { break; case Split::SplitVar: case Split::RenameVar: - case Split::PurifyRVar: if (split.inner == old_name) { split.inner = new_name; found = true; diff --git a/src/Func.h b/src/Func.h index ae739b8dc538..d0d566a2e1c2 100644 --- a/src/Func.h +++ b/src/Func.h @@ -60,6 +60,7 @@ struct VarOrRVar { class ImageParam; namespace Internal { +struct AssociativeOp; class Function; struct Split; struct StorageDim; @@ -81,7 +82,6 @@ class Stage { void split(const std::string &old, const std::string &outer, const std::string &inner, const Expr &factor, bool exact, TailStrategy tail); void remove(const std::string &var); - Stage &purify(const VarOrRVar &old_name, const VarOrRVar &new_name); const std::vector &storage_dims() const { return function.schedule().storage_dims(); @@ -89,6 +89,9 @@ class Stage { Stage &compute_with(LoopLevel loop_level, const std::map &align); + std::pair, std::vector> + rfactor_validate_args(const std::vector> &preserved, const Internal::AssociativeOp &prover_result); + public: Stage(Internal::Function f, Internal::Definition d, size_t stage_index) : function(std::move(f)), definition(std::move(d)), stage_index(stage_index) { @@ -184,7 +187,7 @@ class Stage { * */ // @{ - Func rfactor(std::vector> preserved); + Func rfactor(const std::vector> &preserved); Func rfactor(const RVar &r, const Var &v); // @} diff --git a/src/Inline.cpp b/src/Inline.cpp index 54399cf77b76..31b6efcdf749 100644 --- a/src/Inline.cpp +++ b/src/Inline.cpp @@ -76,8 +76,6 @@ void validate_schedule_inlined_function(Function f) { << split.inner << " because " << f.name() << " is scheduled inline.\n"; - break; - case Split::PurifyRVar: break; } } diff --git a/src/Schedule.h b/src/Schedule.h index ea2692752a9e..906dbe6c7b5e 100644 --- a/src/Schedule.h +++ b/src/Schedule.h @@ -332,8 +332,7 @@ struct Split { enum SplitType { SplitVar = 0, RenameVar, - FuseVars, - PurifyRVar }; + FuseVars }; // If split_type is Rename, then this is just a renaming of the // old_var to the outer and not a split. The inner var should @@ -341,10 +340,6 @@ struct Split { // the same list as splits so that ordering between them is // respected. - // If split type is Purify, this replaces the old_var RVar to - // the outer Var. The inner var should be ignored, and factor - // should be one. - // If split_type is Fuse, then this does the opposite of a // split, it joins the outer and inner into the old_var. SplitType split_type; diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index c7a257dd085e..02f5553f748c 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -178,7 +178,6 @@ Stmt build_loop_nest( const auto &dims = func.args(); const auto &func_s = func.schedule(); const auto &stage_s = def.schedule(); - const auto &predicates = def.split_predicate(); // We'll build it from inside out, starting from the body, // then wrapping it in for loops. @@ -306,7 +305,7 @@ Stmt build_loop_nest( } // Put all the reduction domain predicates into the containers vector. - for (Expr pred : predicates) { + for (Expr pred : def.split_predicate()) { pred = qualify(prefix, pred); // Add a likely qualifier if there isn't already one if (Call::as_intrinsic(pred, {Call::likely, Call::likely_if_innermost})) { @@ -413,8 +412,7 @@ Stmt build_loop_nest( } // Define the bounds on the split dimensions using the bounds - // on the function args. If it is a purify, we should use the bounds - // from the dims instead. + // on the function args. for (const Split &split : reverse_view(splits)) { vector> let_stmts = compute_loop_bounds_after_split(split, prefix); for (const auto &let_stmt : let_stmts) { @@ -2229,7 +2227,7 @@ bool validate_schedule(Function f, const Stmt &s, const Target &target, bool is_ // // However, there are four types of Split, and the concept of a child var varies across them: // - For a vanilla split, inner and outer are the children and old_var is the parent. - // - For rename and purify, the outer is the child and the inner is meaningless. + // - For rename, the outer is the child and the inner is meaningless. // - For fuse, old_var is the child and inner/outer are the parents. // // (@abadams comments: "I acknowledge that this is gross and should be refactored.") @@ -2249,7 +2247,6 @@ bool validate_schedule(Function f, const Stmt &s, const Target &target, bool is_ } break; case Split::RenameVar: - case Split::PurifyRVar: if (parallel_vars.count(split.outer)) { parallel_vars.insert(split.old_var); } diff --git a/src/Serialization.cpp b/src/Serialization.cpp index 15722d878974..d731d9c9d85c 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -338,8 +338,6 @@ Serialize::SplitType Serializer::serialize_split_type(const Split::SplitType &sp return Serialize::SplitType::RenameVar; case Split::SplitType::FuseVars: return Serialize::SplitType::FuseVars; - case Split::SplitType::PurifyRVar: - return Serialize::SplitType::PurifyRVar; default: user_error << "Unsupported split type\n"; return Serialize::SplitType::SplitVar; diff --git a/src/Solve.cpp b/src/Solve.cpp index 3f124601345a..20d6f5200101 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -1239,6 +1239,10 @@ Expr and_condition_over_domain(const Expr &e, const Scope &varying) { return simplify(bounds.min); } +Expr or_condition_over_domain(const Expr &c, const Scope &varying) { + return simplify(!and_condition_over_domain(simplify(!c), varying)); +} + // Testing code namespace { diff --git a/src/Solve.h b/src/Solve.h index ff5124e508c6..4d06fda47d6b 100644 --- a/src/Solve.h +++ b/src/Solve.h @@ -47,6 +47,13 @@ Interval solve_for_inner_interval(const Expr &c, const std::string &variable); * 'and' over the vector lanes, and return a scalar result. */ Expr and_condition_over_domain(const Expr &c, const Scope &varying); +/** Take a conditional that includes variables that vary over some + * domain, and convert it to a weaker (less frequently false) condition + * that doesn't depend on those variables. Formally, the input expr + * implies the output expr. Note that this function might be unable to + * provide a better response than simply const_true(). */ +Expr or_condition_over_domain(const Expr &c, const Scope &varying); + void solve_test(); } // namespace Internal diff --git a/src/Substitute.h b/src/Substitute.h index 22bdf640b7a8..e514fda40359 100644 --- a/src/Substitute.h +++ b/src/Substitute.h @@ -6,6 +6,7 @@ * Defines methods for substituting out variables in expressions and * statements. */ +#include #include #include "Expr.h" @@ -37,6 +38,16 @@ Expr substitute(const Expr &find, const Expr &replacement, const Expr &expr); Stmt substitute(const Expr &find, const Expr &replacement, const Stmt &stmt); // @} +/** Substitute a container of Exprs or Stmts out of place */ +template +T substitute(const std::map &replacements, const T &container) { + T output; + std::transform(container.begin(), container.end(), std::back_inserter(output), [&](const auto &expr_or_stmt) { + return substitute(replacements, expr_or_stmt); + }); + return output; +} + /** Substitutions where the IR may be a general graph (and not just a * DAG). */ // @{ diff --git a/src/halide_ir.fbs b/src/halide_ir.fbs index efc465cbee82..499488ce8b95 100644 --- a/src/halide_ir.fbs +++ b/src/halide_ir.fbs @@ -548,7 +548,6 @@ enum SplitType: ubyte { SplitVar, RenameVar, FuseVars, - PurifyRVar, } table Split { diff --git a/test/common/expect_abort.cpp b/test/common/expect_abort.cpp index cb09a7242921..fec89b0913b7 100644 --- a/test/common/expect_abort.cpp +++ b/test/common/expect_abort.cpp @@ -19,6 +19,9 @@ auto handler = ([]() { << std::flush; suppress_abort = false; std::abort(); // We should never EXPECT an internal error + } catch (const Halide::Error &e) { + std::cerr << e.what() << "\n" + << std::flush; } catch (const std::exception &e) { std::cerr << e.what() << "\n" << std::flush; diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 0478c3b11087..5272b2717de7 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -94,6 +94,8 @@ tests(GROUPS error require_fail.cpp reuse_var_in_schedule.cpp reused_args.cpp + rfactor_after_var_and_rvar_fusion.cpp + rfactor_fused_var_and_rvar.cpp rfactor_inner_dim_non_commutative.cpp round_up_and_blend_race.cpp run_with_large_stack_throws.cpp diff --git a/test/error/rfactor_after_var_and_rvar_fusion.cpp b/test/error/rfactor_after_var_and_rvar_fusion.cpp new file mode 100644 index 000000000000..acda4e4bb6fb --- /dev/null +++ b/test/error/rfactor_after_var_and_rvar_fusion.cpp @@ -0,0 +1,25 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + Func f{"f"}; + RDom r({{0, 5}, {0, 5}, {0, 5}}, "r"); + Var x{"x"}, y{"y"}; + f(x, y) = 0; + f(x, y) += r.x + r.y + r.z; + + RVar rxy{"rxy"}, yrz{"yrz"}; + Var z{"z"}; + + // Error: In schedule for f.update(0), can't perform rfactor() after fusing y and r$z + f.update() + .fuse(r.x, r.y, rxy) + .fuse(r.z, y, yrz) + .rfactor(rxy, z); + + f.print_loop_nest(); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/rfactor_fused_var_and_rvar.cpp b/test/error/rfactor_fused_var_and_rvar.cpp new file mode 100644 index 000000000000..64a79c269690 --- /dev/null +++ b/test/error/rfactor_fused_var_and_rvar.cpp @@ -0,0 +1,26 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + Func f{"f"}; + RDom r({{0, 5}, {0, 5}, {0, 5}}, "r"); + Var x{"x"}, y{"y"}; + f(x, y) = 0; + f(x, y) += r.x + r.y + r.z; + + RVar rxy{"rxy"}, yrz{"yrz"}, yr{"yr"}; + Var z{"z"}; + + // Error: In schedule for f.update(0), can't perform rfactor() after fusing r$z and y + f.update() + .fuse(r.x, r.y, rxy) + .fuse(y, r.z, yrz) + .fuse(rxy, yrz, yr) + .rfactor(yr, z); + + f.print_loop_nest(); + + printf("Success!\n"); + return 0; +}