Skip to content

Commit

Permalink
Use lossless_cast for saturating casts from unsigned to signed on x86 (#8527)
Browse files Browse the repository at this point in the history

* Use lossless_cast to handle saturating casts from unsigned to signed in x86

This takes care of a TODO. I also moved that block of code into an if
statement that only considers saturating casts, to avoid wasting time on
it for all other call nodes.

* Fix incorrect vector width in test
  • Loading branch information
abadams authored Dec 20, 2024
1 parent bc9dfbf commit 526364f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 43 deletions.
81 changes: 38 additions & 43 deletions src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,6 @@ void CodeGen_X86::visit(const Call *op) {
Expr pattern;
};

// clang-format off
static const Pattern patterns[] = {
{"pmulh", mul_shift_right(wild_i16x_, wild_i16x_, 16)},
{"pmulh", mul_shift_right(wild_u16x_, wild_u16x_, 16)},
Expand All @@ -656,7 +655,6 @@ void CodeGen_X86::visit(const Call *op) {
{"saturating_narrow", i8_sat(wild_i16x_)},
{"saturating_narrow", u8_sat(wild_i16x_)},
};
// clang-format on

vector<Expr> matches;
for (const auto &pattern : patterns) {
Expand All @@ -668,52 +666,49 @@ void CodeGen_X86::visit(const Call *op) {
}
}

// clang-format off
static const Pattern reinterpret_patterns[] = {
{"saturating_narrow", i16_sat(wild_u32x_)},
{"saturating_narrow", u16_sat(wild_u32x_)},
{"saturating_narrow", i8_sat(wild_u16x_)},
{"saturating_narrow", u8_sat(wild_u16x_)},
};
// clang-format on
if (op->is_intrinsic(Call::saturating_cast)) {

// Search for saturating casts where the inner value can be
// reinterpreted to signed, so that we can use existing
// saturating_narrow instructions.
// TODO: should use lossless_cast once it is fixed.
for (const auto &pattern : reinterpret_patterns) {
if (expr_match(pattern.pattern, op, matches)) {
const Expr &expr = matches[0];
const Type &t = expr.type();
// TODO(8212): might want to keep track of scope of bounds information.
const ConstantInterval ibounds = constant_integer_bounds(expr);
const Type reint_type = t.with_code(halide_type_int);
// If the signed type can represent the maximum value of the unsigned
// expression, we can safely reinterpret this unsigned expression as signed.
if (reint_type.can_represent(ibounds)) {
// Can safely reinterpret to signed integer.
matches[0] = cast(reint_type, matches[0]);
value = call_overloaded_intrin(op->type, pattern.intrin, matches);
if (value) {
return;
static const Pattern reinterpret_patterns[] = {
{"saturating_narrow", i16_sat(wild_u32x_)},
{"saturating_narrow", u16_sat(wild_u32x_)},
{"saturating_narrow", i8_sat(wild_u16x_)},
{"saturating_narrow", u8_sat(wild_u16x_)},
};

// Search for saturating casts where the inner value can be
// reinterpreted to signed, so that we can use existing
// saturating_narrow instructions.
for (const auto &pattern : reinterpret_patterns) {
if (expr_match(pattern.pattern, op, matches)) {
const Type signed_type = matches[0].type().with_code(halide_type_int);
Expr e = lossless_cast(signed_type, matches[0]);
if (e.defined()) {
// Can safely reinterpret to signed integer.
matches[0] = e;
value = call_overloaded_intrin(op->type, pattern.intrin, matches);
if (value) {
return;
}
}
// No reinterpret patterns match the same input, so stop matching.
break;
}
// No reinterpret patterns match the same input, so stop matching.
break;
}
}

static const vector<pair<Expr, Expr>> cast_rewrites = {
// Some double-narrowing saturating casts can be better expressed as
// combinations of single-narrowing saturating casts.
{u8_sat(wild_i32x_), u8_sat(i16_sat(wild_i32x_))},
{i8_sat(wild_i32x_), i8_sat(i16_sat(wild_i32x_))},
};
for (const auto &i : cast_rewrites) {
if (expr_match(i.first, op, matches)) {
Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes()));
value = codegen(replacement);
return;
static const vector<pair<Expr, Expr>> cast_rewrites = {
// Some double-narrowing saturating casts can be better expressed as
// combinations of single-narrowing saturating casts.
{u8_sat(wild_i32x_), u8_sat(i16_sat(wild_i32x_))},
{i8_sat(wild_i32x_), i8_sat(i16_sat(wild_i32x_))},
{i8_sat(wild_u32x_), i8_sat(i16_sat(wild_u32x_))},
};

for (const auto &i : cast_rewrites) {
if (expr_match(i.first, op, matches)) {
Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes()));
value = codegen(replacement);
return;
}
}
}

Expand Down
6 changes: 6 additions & 0 deletions test/correctness/simd_op_check_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ class SimdOpCheckX86 : public SimdOpCheckTest {
check(std::string("packssdw") + check_suffix, 8 * w, u8_sat(i32_1));
check(std::string("packssdw") + check_suffix, 8 * w, i8_sat(i32_1));

// A uint without the top bit set can be reinterpreted as an int
// so that packssdw (or packsswb for 16-bit inputs) can be used.
check(std::string("packssdw") + check_suffix, 4 * w, i16_sat(u32_1 >> 1));
check(std::string("packssdw") + check_suffix, 8 * w, i8_sat(u32_1 >> 1));
check(std::string("packsswb") + check_suffix, 8 * w, i8_sat(u16_1 >> 1));

// Sum-of-absolute-difference ops
{
const int f = 8; // reduction factor.
Expand Down

0 comments on commit 526364f

Please sign in to comment.