Fix UB-introducing rewrite in FindIntrinsics (#8539)
FindIntrinsics was rewriting

i8(rounding_shift_left(i16(foo_i8), 11))

to

rounding_shift_left(foo_i8, 11)

That is, it decided it could do the shift in the narrower type. However,
this isn't correct, because 11 is a valid shift amount for a 16-bit int
but not for an 8-bit int. The former evaluates to zero, while the latter
becomes a poison value because we lower it to LLVM's shl. This was
discovered by a random failure in test/correctness/lossless_cast.cpp in
another PR, for seed 826708018.
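
For illustration, a minimal reproducer sketch in Halide's C++ front end,
modeled on the regression test added to test/correctness/intrinsics.cpp
below (the Func and Var names are illustrative, not from the failing
pipeline):

#include "Halide.h"
using namespace Halide;
using namespace Halide::ConciseCasts;

int main(int argc, char **argv) {
    Func f("f"), g("g");
    Var x("x");
    f(x) = cast<int8_t>(x);
    f.compute_root();
    // The shift amount 11 is in range for the 16-bit intermediate, so the
    // low 8 bits of the result are zero and the narrowing cast yields 0.
    // Before this fix, FindIntrinsics rewrote this into an 8-bit
    // rounding_shift_left by 11, which lowers to a poison-producing LLVM shl.
    g(x) = i8(rounding_shift_left(i16(f(x)), 11));
    g.compile_jit();
    return 0;
}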

This PR fixes the case by adding a compile-time check that the shift
amount is in range. For the examples in test/correctness/intrinsics.cpp
the shift amount ends up in a let, so making the rewrite still fire in
those cases required handling a TODO: tracking the constant integer
bounds of variables in scope in FindIntrinsics, and therefore also in
lossless_cast.
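
As a usage sketch of what the check enables (assuming the public
Halide::widening_add helper and the ConciseCasts namespace; the Param and
Func names are illustrative): clamping a variable shift amount gives
Halide a provable constant bound, so the rounding-shift pattern can still
be recognized for a narrow type, as in the updated variable-shift cases in
test/correctness/intrinsics.cpp.

#include "Halide.h"
using namespace Halide;
using namespace Halide::ConciseCasts;

int main(int argc, char **argv) {
    Var x("x");
    Param<uint8_t> shift("shift");
    Func f("f");
    // min() bounds the shift amount to [0, 7], which is in range for an
    // 8-bit type, so FindIntrinsics can fuse this into a rounding shift
    // without risking an out-of-range (UB) shift.
    Expr s = min(shift, cast<uint8_t>(7));
    f(x) = i8(widening_add(i8(x), (i8(1) << s) / 2) >> s);
    f.compile_jit();
    return 0;
}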
abadams authored Dec 27, 2024
1 parent 097aee9 commit 5783534
Showing 4 changed files with 109 additions and 48 deletions.
63 changes: 58 additions & 5 deletions src/FindIntrinsics.cpp
@@ -101,11 +101,10 @@ class FindIntrinsics : public IRMutator {
Scope<ConstantInterval> let_var_bounds;

Expr lossless_cast(Type t, const Expr &e) {
return Halide::Internal::lossless_cast(t, e, &bounds_cache);
return Halide::Internal::lossless_cast(t, e, let_var_bounds, &bounds_cache);
}

ConstantInterval constant_integer_bounds(const Expr &e) {
// TODO: Use the scope - add let visitors
return Halide::Internal::constant_integer_bounds(e, let_var_bounds, &bounds_cache);
}

@@ -210,6 +209,51 @@ class FindIntrinsics : public IRMutator {
return Expr();
}

template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
struct Frame {
const LetOrLetStmt *orig;
Expr new_value;
ScopedBinding<ConstantInterval> bind;
Frame(const LetOrLetStmt *orig,
Expr &&new_value,
ScopedBinding<ConstantInterval> &&bind)
: orig(orig), new_value(std::move(new_value)), bind(std::move(bind)) {
}
};
std::vector<Frame> frames;
decltype(op->body) body;
while (op) {
Expr v = mutate(op->value);
ConstantInterval b = constant_integer_bounds(v);
frames.emplace_back(op,
std::move(v),
ScopedBinding<ConstantInterval>(let_var_bounds, op->name, b));
body = op->body;
op = body.template as<LetOrLetStmt>();
}

body = mutate(body);

for (const auto &f : reverse_view(frames)) {
if (f.new_value.same_as(f.orig->value) && body.same_as(f.orig->body)) {
body = f.orig;
} else {
body = LetOrLetStmt::make(f.orig->name, f.new_value, body);
}
}

return body;
}

Expr visit(const Let *op) override {
return visit_let(op);
}

Stmt visit(const LetStmt *op) override {
return visit_let(op);
}

Expr visit(const Add *op) override {
if (!find_intrinsics_for_type(op->type)) {
return IRMutator::visit(op);
@@ -697,7 +741,12 @@ class FindIntrinsics : public IRMutator {
bool is_saturated = op->value.as<Max>() || op->value.as<Min>();
Expr a = lossless_cast(op->type, shift->args[0]);
Expr b = lossless_cast(op->type.with_code(shift->args[1].type().code()), shift->args[1]);
if (a.defined() && b.defined()) {
// Doing the shift in the narrower type might introduce UB where
// there was no UB before, so we need to make sure b is bounded.
auto b_bounds = constant_integer_bounds(b);
const int max_shift = op->type.bits() - 1;

if (a.defined() && b.defined() && b_bounds >= -max_shift && b_bounds <= max_shift) {
if (!is_saturated ||
(shift->is_intrinsic(Call::rounding_shift_right) && can_prove(b >= 0)) ||
(shift->is_intrinsic(Call::rounding_shift_left) && can_prove(b <= 0))) {
@@ -1118,8 +1167,12 @@ class SubstituteInWideningLets : public IRMutator {
std::string name;
Expr new_value;
ScopedBinding<Expr> bind;
Frame(const std::string &name, const Expr &new_value, ScopedBinding<Expr> &&bind)
: name(name), new_value(new_value), bind(std::move(bind)) {
Frame(const std::string &name,
const Expr &new_value,
ScopedBinding<Expr> &&bind)
: name(name),
new_value(new_value),
bind(std::move(bind)) {
}
};
std::vector<Frame> frames;
57 changes: 30 additions & 27 deletions src/IROperator.cpp
@@ -427,17 +427,20 @@ Expr const_false(int w) {
return make_zero(UInt(1, w));
}

Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare> *cache) {
Expr lossless_cast(Type t,
Expr e,
const Scope<ConstantInterval> &scope,
std::map<Expr, ConstantInterval, ExprCompare> *cache) {
if (!e.defined() || t == e.type()) {
return e;
} else if (t.can_represent(e.type())) {
return cast(t, std::move(e));
} else if (const Cast *c = e.as<Cast>()) {
if (c->type.can_represent(c->value.type())) {
return lossless_cast(t, c->value, cache);
return lossless_cast(t, c->value, scope, cache);
}
} else if (const Broadcast *b = e.as<Broadcast>()) {
Expr v = lossless_cast(t.element_of(), b->value, cache);
Expr v = lossless_cast(t.element_of(), b->value, scope, cache);
if (v.defined()) {
return Broadcast::make(v, b->lanes);
}
Expand All @@ -456,7 +459,7 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
} else if (const Shuffle *shuf = e.as<Shuffle>()) {
std::vector<Expr> vecs;
for (const auto &vec : shuf->vectors) {
vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, cache));
vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, scope, cache));
if (!vecs.back().defined()) {
return Expr();
}
Expand All @@ -465,73 +468,73 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
} else if (t.is_int_or_uint()) {
// Check the bounds. If they're small enough, we can throw narrowing
// casts around e, or subterms.
ConstantInterval ci = constant_integer_bounds(e, Scope<ConstantInterval>::empty_scope(), cache);
ConstantInterval ci = constant_integer_bounds(e, scope, cache);

if (t.can_represent(ci)) {
// There are certain IR nodes where if the result is expressible
// using some type, and the args are expressible using that type,
// then the operation can just be done in that type.
if (const Add *op = e.as<Add>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
return Add::make(a, b);
}
} else if (const Sub *op = e.as<Sub>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
return Sub::make(a, b);
}
} else if (const Mul *op = e.as<Mul>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
return Mul::make(a, b);
}
} else if (const Min *op = e.as<Min>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
debug(0) << a << " " << b << "\n";
return Min::make(a, b);
}
} else if (const Max *op = e.as<Max>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
return Max::make(a, b);
}
} else if (const Mod *op = e.as<Mod>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
return Mod::make(a, b);
}
} else if (const Call *op = Call::as_intrinsic(e, {Call::widening_add, Call::widen_right_add})) {
Expr a = lossless_cast(t, op->args[0], cache);
Expr b = lossless_cast(t, op->args[1], cache);
Expr a = lossless_cast(t, op->args[0], scope, cache);
Expr b = lossless_cast(t, op->args[1], scope, cache);
if (a.defined() && b.defined()) {
return Add::make(a, b);
}
} else if (const Call *op = Call::as_intrinsic(e, {Call::widening_sub, Call::widen_right_sub})) {
Expr a = lossless_cast(t, op->args[0], cache);
Expr b = lossless_cast(t, op->args[1], cache);
Expr a = lossless_cast(t, op->args[0], scope, cache);
Expr b = lossless_cast(t, op->args[1], scope, cache);
if (a.defined() && b.defined()) {
return Sub::make(a, b);
}
} else if (const Call *op = Call::as_intrinsic(e, {Call::widening_mul, Call::widen_right_mul})) {
Expr a = lossless_cast(t, op->args[0], cache);
Expr b = lossless_cast(t, op->args[1], cache);
Expr a = lossless_cast(t, op->args[0], scope, cache);
Expr b = lossless_cast(t, op->args[1], scope, cache);
if (a.defined() && b.defined()) {
return Mul::make(a, b);
}
} else if (const Call *op = Call::as_intrinsic(e, {Call::shift_left, Call::widening_shift_left,
Call::shift_right, Call::widening_shift_right})) {
Expr a = lossless_cast(t, op->args[0], cache);
Expr b = lossless_cast(t, op->args[1], cache);
Expr a = lossless_cast(t, op->args[0], scope, cache);
Expr b = lossless_cast(t, op->args[1], scope, cache);
if (a.defined() && b.defined()) {
ConstantInterval cb = constant_integer_bounds(b, Scope<ConstantInterval>::empty_scope(), cache);
ConstantInterval cb = constant_integer_bounds(b, scope, cache);
if (cb > -t.bits() && cb < t.bits()) {
if (op->is_intrinsic({Call::shift_left, Call::widening_shift_left})) {
return a << b;
Expand All @@ -544,7 +547,7 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
if (op->op == VectorReduce::Add ||
op->op == VectorReduce::Min ||
op->op == VectorReduce::Max) {
Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, cache);
Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, scope, cache);
if (v.defined()) {
return VectorReduce::make(op->op, v, op->type.lanes());
}
19 changes: 12 additions & 7 deletions src/IROperator.h
@@ -11,7 +11,9 @@
#include <map>
#include <optional>

#include "ConstantInterval.h"
#include "Expr.h"
#include "Scope.h"
#include "Target.h"
#include "Tuple.h"

@@ -150,13 +152,16 @@ Expr const_false(int lanes = 1);
/** Attempt to cast an expression to a smaller type while provably not losing
* information. If it can't be done, return an undefined Expr.
*
* Optionally accepts a map that gives the constant bounds of exprs already
* analyzed to avoid redoing work across many calls to lossless_cast. It is not
* safe to use this optional map in contexts where the same Expr object may
* take on a different value. For example:
* (let x = 4 in some_expr_object) + (let x = 5 in the_same_expr_object)).
* It is safe to use it after uniquify_variable_names has been run. */
Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare> *cache = nullptr);
* Optionally accepts a scope giving the constant bounds of any variables, and a
* map that gives the constant bounds of exprs already analyzed to avoid redoing
* work across many calls to lossless_cast. It is not safe to use this optional
* map in contexts where the same Expr object may take on a different value. For
* example: (let x = 4 in some_expr_object) + (let x = 5 in
* the_same_expr_object)). It is safe to use it after uniquify_variable_names
* has been run. */
Expr lossless_cast(Type t, Expr e,
const Scope<ConstantInterval> &scope = Scope<ConstantInterval>::empty_scope(),
std::map<Expr, ConstantInterval, ExprCompare> *cache = nullptr);

/** Attempt to negate x without introducing new IR and without overflow.
* If it can't be done, return an undefined Expr. */
18 changes: 9 additions & 9 deletions test/correctness/intrinsics.cpp
@@ -255,17 +255,17 @@ int main(int argc, char **argv) {
check((u64(u32x) + 8) / 16, u64(rounding_shift_right(u32x, 4)));
check(u16(min((u64(u32x) + 8) / 16, 65535)), u16_sat(rounding_shift_right(u32x, 4)));

// And with variable shifts.
check(i8(widening_add(i8x, (i8(1) << u8y) / 2) >> u8y), rounding_shift_right(i8x, u8y));
// And with variable shifts. These won't match unless Halide can statically
// prove it's not an out-of-range shift.
Expr u8yc = Min::make(u8y, make_const(u8y.type(), 7));
Expr i8yc = Max::make(Min::make(i8y, make_const(i8y.type(), 7)), make_const(i8y.type(), -7));
check(i8(widening_add(i8x, (i8(1) << u8yc) / 2) >> u8yc), rounding_shift_right(i8x, u8yc));
check((i32x + (i32(1) << u32y) / 2) >> u32y, rounding_shift_right(i32x, u32y));

check(i8(widening_add(i8x, (i8(1) << max(i8y, 0)) / 2) >> i8y), rounding_shift_right(i8x, i8y));
check(i8(widening_add(i8x, (i8(1) << max(i8yc, 0)) / 2) >> i8yc), rounding_shift_right(i8x, i8yc));
check((i32x + (i32(1) << max(i32y, 0)) / 2) >> i32y, rounding_shift_right(i32x, i32y));

check(i8(widening_add(i8x, (i8(1) >> min(i8y, 0)) / 2) << i8y), rounding_shift_left(i8x, i8y));
check(i8(widening_add(i8x, (i8(1) >> min(i8yc, 0)) / 2) << i8yc), rounding_shift_left(i8x, i8yc));
check((i32x + (i32(1) >> min(i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y));

check(i8(widening_add(i8x, (i8(1) << -min(i8y, 0)) / 2) << i8y), rounding_shift_left(i8x, i8y));
check(i8(widening_add(i8x, (i8(1) << -min(i8yc, 0)) / 2) << i8yc), rounding_shift_left(i8x, i8yc));
check((i32x + (i32(1) << -min(i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y));
check((i32x + (i32(1) << max(-i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y));

@@ -372,7 +372,7 @@ int main(int argc, char **argv) {
f(x) = cast<uint8_t>(x);
f.compute_root();

g(x) = rounding_shift_right(x, 0) + rounding_shift_left(x, 8);
g(x) = rounding_shift_right(f(x), 0) + u8(rounding_shift_left(u16(f(x)), 11));

g.compile_jit();
}
