Skip to content

Commit

Permalink
Don't strip strict_float() from lets (#5871)
Browse files Browse the repository at this point in the history
* Don't strip strict_float() from lets

Bug injected in #5856: the change in Simplify_Let.cpp was inadvertently stripping `strict_float()` calls that wrapped the RHS of a Let-expr, which can change results nontrivially in some cases. I don't think a new test for this fix is practical -- it would be a little fragile, as it would rely on the specifics of simplification that could change over time.

As a drive-by, also added an explicit rule to Simplify_Call to ensure that strict_float(strict_float(x)) -> strict_float(x) in *all* cases. (The existing rule didn't do this in all cases.)
  • Loading branch information
steven-johnson authored Mar 31, 2021
1 parent 896b260 commit cb78a6b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
16 changes: 11 additions & 5 deletions src/Simplify_Call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,18 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) {
}

if (op->is_intrinsic(Call::strict_float)) {
ScopedValue<bool> save_no_float_simplify(no_float_simplify, true);
Expr arg = mutate(op->args[0], nullptr);
if (arg.same_as(op->args[0])) {
return op;
if (Call::as_intrinsic(op->args[0], {Call::strict_float})) {
// Always simplify strict_float(strict_float(x)) -> strict_float(x).
Expr arg = mutate(op->args[0], nullptr);
return arg.same_as(op->args[0]) ? op->args[0] : arg;
} else {
return strict_float(arg);
ScopedValue<bool> save_no_float_simplify(no_float_simplify, true);
Expr arg = mutate(op->args[0], nullptr);
if (arg.same_as(op->args[0])) {
return op;
} else {
return strict_float(arg);
}
}
} else if (op->is_intrinsic(Call::popcount) ||
op->is_intrinsic(Call::count_leading_zeros) ||
Expand Down
5 changes: 4 additions & 1 deletion src/Simplify_Let.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
const Shuffle *shuffle = f.new_value.template as<Shuffle>();
const Variable *var_b = nullptr;
const Variable *var_a = nullptr;
const Call *tag = nullptr;

if (add) {
var_a = add->a.as<Variable>();
Expand Down Expand Up @@ -174,7 +175,9 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
Expr op_b = var_a ? new_var : shuffle->vectors[1];
replacement = substitute(f.new_name, Shuffle::make_concat({op_a, op_b}), replacement);
f.new_value = var_a ? shuffle->vectors[1] : shuffle->vectors[0];
} else if (const Call *tag = Call::as_tag(f.new_value)) {
} else if ((tag = Call::as_tag(f.new_value)) != nullptr && !tag->is_intrinsic(Call::strict_float)) {
// Most tags should be stripped here, but not strict_float(); removing it will change the semantics
// of the let-expr we are producing.
replacement = substitute(f.new_name, Call::make(tag->type, tag->name, {new_var}, Call::PureIntrinsic), replacement);
f.new_value = tag->args[0];
} else {
Expand Down
2 changes: 2 additions & 0 deletions test/correctness/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1763,6 +1763,8 @@ void check_math() {
check(Halide::trunc(-1.6f), -1.0f);
check(Halide::floor(round(x)), round(x));
check(Halide::ceil(ceil(x)), ceil(x));

check(strict_float(strict_float(x)), strict_float(x));
}

void check_overflow() {
Expand Down

0 comments on commit cb78a6b

Please sign in to comment.