From 7721411aa1bc38825661ade625882f55681db768 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 23 Dec 2024 14:12:53 -0800 Subject: [PATCH] Fix UB-introducing rewrite in FindIntrinsics FindIntrinsics was rewriting i8(rounding_shift_left(i16(foo_i8), 11)) to rounding_shift_left(foo_i8, 11) I.e. it decided it could do the shift in the narrower type. However this isn't correct, because 11 is a valid shift for a 16-bit int, but not for an 8-bit int. The former is zero, the latter gets turned into a poison value because we lower it to llvm's shl. This was discovered by a random failure in test/correctness/lossless_cast.cpp in another PR for seed 826708018. This PR fixes this case by adding a compile-time check that the shift is in-range. For the examples in test/correctness/intrinsics.cpp the shift amount ends up in a let, so making this work on those cases required handling a TODO: tracking the constant integer bounds of variables in scope in FindIntrinsics and therefore also in lossless_cast. --- src/FindIntrinsics.cpp | 63 ++++++++++++++++++++++++++++++--- src/IROperator.cpp | 57 +++++++++++++++-------------- src/IROperator.h | 19 ++++++---- test/correctness/intrinsics.cpp | 18 +++++----- 4 files changed, 109 insertions(+), 48 deletions(-) diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 4034d7aea629..7d2e522a9937 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -101,11 +101,10 @@ class FindIntrinsics : public IRMutator { Scope let_var_bounds; Expr lossless_cast(Type t, const Expr &e) { - return Halide::Internal::lossless_cast(t, e, &bounds_cache); + return Halide::Internal::lossless_cast(t, e, let_var_bounds, &bounds_cache); } ConstantInterval constant_integer_bounds(const Expr &e) { - // TODO: Use the scope - add let visitors return Halide::Internal::constant_integer_bounds(e, let_var_bounds, &bounds_cache); } @@ -210,6 +209,51 @@ class FindIntrinsics : public IRMutator { return Expr(); } + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { + struct Frame { + const LetOrLetStmt *orig; + Expr new_value; + ScopedBinding bind; + Frame(const LetOrLetStmt *orig, + Expr &&new_value, + ScopedBinding &&bind) + : orig(orig), new_value(std::move(new_value)), bind(std::move(bind)) { + } + }; + std::vector frames; + decltype(op->body) body; + while (op) { + Expr v = mutate(op->value); + ConstantInterval b = constant_integer_bounds(v); + frames.emplace_back(op, + std::move(v), + ScopedBinding(let_var_bounds, op->name, b)); + body = op->body; + op = body.template as(); + } + + body = mutate(body); + + for (const auto &f : reverse_view(frames)) { + if (f.new_value.same_as(f.orig->value) && body.same_as(f.orig->body)) { + body = f.orig; + } else { + body = LetOrLetStmt::make(f.orig->name, f.new_value, body); + } + } + + return body; + } + + Expr visit(const Let *op) override { + return visit_let(op); + } + + Stmt visit(const LetStmt *op) override { + return visit_let(op); + } + Expr visit(const Add *op) override { if (!find_intrinsics_for_type(op->type)) { return IRMutator::visit(op); @@ -697,7 +741,12 @@ class FindIntrinsics : public IRMutator { bool is_saturated = op->value.as() || op->value.as(); Expr a = lossless_cast(op->type, shift->args[0]); Expr b = lossless_cast(op->type.with_code(shift->args[1].type().code()), shift->args[1]); - if (a.defined() && b.defined()) { + // Doing the shift in the narrower type might introduce UB where + // there was no UB before, so we need to make sure b is bounded. + auto b_bounds = constant_integer_bounds(b); + const int max_shift = op->type.bits() - 1; + + if (a.defined() && b.defined() && b_bounds >= -max_shift && b_bounds <= max_shift) { if (!is_saturated || (shift->is_intrinsic(Call::rounding_shift_right) && can_prove(b >= 0)) || (shift->is_intrinsic(Call::rounding_shift_left) && can_prove(b <= 0))) { @@ -1118,8 +1167,12 @@ class SubstituteInWideningLets : public IRMutator { std::string name; Expr new_value; ScopedBinding bind; - Frame(const std::string &name, const Expr &new_value, ScopedBinding &&bind) - : name(name), new_value(new_value), bind(std::move(bind)) { + Frame(const std::string &name, + const Expr &new_value, + ScopedBinding &&bind) + : name(name), + new_value(new_value), + bind(std::move(bind)) { } }; std::vector frames; diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 26869c4d8b15..e2143470e497 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -427,17 +427,20 @@ Expr const_false(int w) { return make_zero(UInt(1, w)); } -Expr lossless_cast(Type t, Expr e, std::map *cache) { +Expr lossless_cast(Type t, + Expr e, + const Scope &scope, + std::map *cache) { if (!e.defined() || t == e.type()) { return e; } else if (t.can_represent(e.type())) { return cast(t, std::move(e)); } else if (const Cast *c = e.as()) { if (c->type.can_represent(c->value.type())) { - return lossless_cast(t, c->value, cache); + return lossless_cast(t, c->value, scope, cache); } } else if (const Broadcast *b = e.as()) { - Expr v = lossless_cast(t.element_of(), b->value, cache); + Expr v = lossless_cast(t.element_of(), b->value, scope, cache); if (v.defined()) { return Broadcast::make(v, b->lanes); } @@ -456,7 +459,7 @@ Expr lossless_cast(Type t, Expr e, std::map } else if (const Shuffle *shuf = e.as()) { std::vector vecs; for (const auto &vec : shuf->vectors) { - vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, cache)); + vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, scope, cache)); if (!vecs.back().defined()) { return Expr(); } @@ -465,73 +468,73 @@ Expr lossless_cast(Type t, Expr e, std::map } else if (t.is_int_or_uint()) { // Check the bounds. If they're small enough, we can throw narrowing // casts around e, or subterms. - ConstantInterval ci = constant_integer_bounds(e, Scope::empty_scope(), cache); + ConstantInterval ci = constant_integer_bounds(e, scope, cache); if (t.can_represent(ci)) { // There are certain IR nodes where if the result is expressible // using some type, and the args are expressible using that type, // then the operation can just be done in that type. if (const Add *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { return Add::make(a, b); } } else if (const Sub *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { return Sub::make(a, b); } } else if (const Mul *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { return Mul::make(a, b); } } else if (const Min *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { debug(0) << a << " " << b << "\n"; return Min::make(a, b); } } else if (const Max *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { return Max::make(a, b); } } else if (const Mod *op = e.as()) { - Expr a = lossless_cast(t, op->a, cache); - Expr b = lossless_cast(t, op->b, cache); + Expr a = lossless_cast(t, op->a, scope, cache); + Expr b = lossless_cast(t, op->b, scope, cache); if (a.defined() && b.defined()) { return Mod::make(a, b); } } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_add, Call::widen_right_add})) { - Expr a = lossless_cast(t, op->args[0], cache); - Expr b = lossless_cast(t, op->args[1], cache); + Expr a = lossless_cast(t, op->args[0], scope, cache); + Expr b = lossless_cast(t, op->args[1], scope, cache); if (a.defined() && b.defined()) { return Add::make(a, b); } } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_sub, Call::widen_right_sub})) { - Expr a = lossless_cast(t, op->args[0], cache); - Expr b = lossless_cast(t, op->args[1], cache); + Expr a = lossless_cast(t, op->args[0], scope, cache); + Expr b = lossless_cast(t, op->args[1], scope, cache); if (a.defined() && b.defined()) { return Sub::make(a, b); } } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_mul, Call::widen_right_mul})) { - Expr a = lossless_cast(t, op->args[0], cache); - Expr b = lossless_cast(t, op->args[1], cache); + Expr a = lossless_cast(t, op->args[0], scope, cache); + Expr b = lossless_cast(t, op->args[1], scope, cache); if (a.defined() && b.defined()) { return Mul::make(a, b); } } else if (const Call *op = Call::as_intrinsic(e, {Call::shift_left, Call::widening_shift_left, Call::shift_right, Call::widening_shift_right})) { - Expr a = lossless_cast(t, op->args[0], cache); - Expr b = lossless_cast(t, op->args[1], cache); + Expr a = lossless_cast(t, op->args[0], scope, cache); + Expr b = lossless_cast(t, op->args[1], scope, cache); if (a.defined() && b.defined()) { - ConstantInterval cb = constant_integer_bounds(b, Scope::empty_scope(), cache); + ConstantInterval cb = constant_integer_bounds(b, scope, cache); if (cb > -t.bits() && cb < t.bits()) { if (op->is_intrinsic({Call::shift_left, Call::widening_shift_left})) { return a << b; @@ -544,7 +547,7 @@ Expr lossless_cast(Type t, Expr e, std::map if (op->op == VectorReduce::Add || op->op == VectorReduce::Min || op->op == VectorReduce::Max) { - Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, cache); + Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, scope, cache); if (v.defined()) { return VectorReduce::make(op->op, v, op->type.lanes()); } diff --git a/src/IROperator.h b/src/IROperator.h index 0db5606f011c..34d859521111 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -11,7 +11,9 @@ #include #include +#include "ConstantInterval.h" #include "Expr.h" +#include "Scope.h" #include "Target.h" #include "Tuple.h" @@ -150,13 +152,16 @@ Expr const_false(int lanes = 1); /** Attempt to cast an expression to a smaller type while provably not losing * information. If it can't be done, return an undefined Expr. * - * Optionally accepts a map that gives the constant bounds of exprs already - * analyzed to avoid redoing work across many calls to lossless_cast. It is not - * safe to use this optional map in contexts where the same Expr object may - * take on a different value. For example: - * (let x = 4 in some_expr_object) + (let x = 5 in the_same_expr_object)). - * It is safe to use it after uniquify_variable_names has been run. */ -Expr lossless_cast(Type t, Expr e, std::map *cache = nullptr); + * Optionally accepts a scope giving the constant bounds of any variables, and a + * map that gives the constant bounds of exprs already analyzed to avoid redoing + * work across many calls to lossless_cast. It is not safe to use this optional + * map in contexts where the same Expr object may take on a different value. For + * example: (let x = 4 in some_expr_object) + (let x = 5 in + * the_same_expr_object)). It is safe to use it after uniquify_variable_names + * has been run. */ +Expr lossless_cast(Type t, Expr e, + const Scope &scope = Scope::empty_scope(), + std::map *cache = nullptr); /** Attempt to negate x without introducing new IR and without overflow. * If it can't be done, return an undefined Expr. */ diff --git a/test/correctness/intrinsics.cpp b/test/correctness/intrinsics.cpp index e5119bd5e1be..757e1360be36 100644 --- a/test/correctness/intrinsics.cpp +++ b/test/correctness/intrinsics.cpp @@ -255,17 +255,17 @@ int main(int argc, char **argv) { check((u64(u32x) + 8) / 16, u64(rounding_shift_right(u32x, 4))); check(u16(min((u64(u32x) + 8) / 16, 65535)), u16_sat(rounding_shift_right(u32x, 4))); - // And with variable shifts. - check(i8(widening_add(i8x, (i8(1) << u8y) / 2) >> u8y), rounding_shift_right(i8x, u8y)); + // And with variable shifts. These won't match unless Halide can statically + // prove it's not an out-of-range shift. + Expr u8yc = Min::make(u8y, make_const(u8y.type(), 7)); + Expr i8yc = Max::make(Min::make(i8y, make_const(i8y.type(), 7)), make_const(i8y.type(), -7)); + check(i8(widening_add(i8x, (i8(1) << u8yc) / 2) >> u8yc), rounding_shift_right(i8x, u8yc)); check((i32x + (i32(1) << u32y) / 2) >> u32y, rounding_shift_right(i32x, u32y)); - - check(i8(widening_add(i8x, (i8(1) << max(i8y, 0)) / 2) >> i8y), rounding_shift_right(i8x, i8y)); + check(i8(widening_add(i8x, (i8(1) << max(i8yc, 0)) / 2) >> i8yc), rounding_shift_right(i8x, i8yc)); check((i32x + (i32(1) << max(i32y, 0)) / 2) >> i32y, rounding_shift_right(i32x, i32y)); - - check(i8(widening_add(i8x, (i8(1) >> min(i8y, 0)) / 2) << i8y), rounding_shift_left(i8x, i8y)); + check(i8(widening_add(i8x, (i8(1) >> min(i8yc, 0)) / 2) << i8yc), rounding_shift_left(i8x, i8yc)); check((i32x + (i32(1) >> min(i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y)); - - check(i8(widening_add(i8x, (i8(1) << -min(i8y, 0)) / 2) << i8y), rounding_shift_left(i8x, i8y)); + check(i8(widening_add(i8x, (i8(1) << -min(i8yc, 0)) / 2) << i8yc), rounding_shift_left(i8x, i8yc)); check((i32x + (i32(1) << -min(i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y)); check((i32x + (i32(1) << max(-i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y)); @@ -372,7 +372,7 @@ int main(int argc, char **argv) { f(x) = cast(x); f.compute_root(); - g(x) = rounding_shift_right(x, 0) + rounding_shift_left(x, 8); + g(x) = rounding_shift_right(f(x), 0) + u8(rounding_shift_left(u16(f(x)), 11)); g.compile_jit(); }