diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 90609e1477c6..3d3f071b6b0f 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -647,7 +647,6 @@ void CodeGen_X86::visit(const Call *op) { Expr pattern; }; - // clang-format off static const Pattern patterns[] = { {"pmulh", mul_shift_right(wild_i16x_, wild_i16x_, 16)}, {"pmulh", mul_shift_right(wild_u16x_, wild_u16x_, 16)}, @@ -656,7 +655,6 @@ void CodeGen_X86::visit(const Call *op) { {"saturating_narrow", i8_sat(wild_i16x_)}, {"saturating_narrow", u8_sat(wild_i16x_)}, }; - // clang-format on vector matches; for (const auto &pattern : patterns) { @@ -668,52 +666,49 @@ void CodeGen_X86::visit(const Call *op) { } } - // clang-format off - static const Pattern reinterpret_patterns[] = { - {"saturating_narrow", i16_sat(wild_u32x_)}, - {"saturating_narrow", u16_sat(wild_u32x_)}, - {"saturating_narrow", i8_sat(wild_u16x_)}, - {"saturating_narrow", u8_sat(wild_u16x_)}, - }; - // clang-format on + if (op->is_intrinsic(Call::saturating_cast)) { - // Search for saturating casts where the inner value can be - // reinterpreted to signed, so that we can use existing - // saturating_narrow instructions. - // TODO: should use lossless_cast once it is fixed. - for (const auto &pattern : reinterpret_patterns) { - if (expr_match(pattern.pattern, op, matches)) { - const Expr &expr = matches[0]; - const Type &t = expr.type(); - // TODO(8212): might want to keep track of scope of bounds information. - const ConstantInterval ibounds = constant_integer_bounds(expr); - const Type reint_type = t.with_code(halide_type_int); - // If the signed type can represent the maximum value unsigned value, - // we can safely reinterpret this unsigned expression as signed. - if (reint_type.can_represent(ibounds)) { - // Can safely reinterpret to signed integer. - matches[0] = cast(reint_type, matches[0]); - value = call_overloaded_intrin(op->type, pattern.intrin, matches); - if (value) { - return; + static const Pattern reinterpret_patterns[] = { + {"saturating_narrow", i16_sat(wild_u32x_)}, + {"saturating_narrow", u16_sat(wild_u32x_)}, + {"saturating_narrow", i8_sat(wild_u16x_)}, + {"saturating_narrow", u8_sat(wild_u16x_)}, + }; + + // Search for saturating casts where the inner value can be + // reinterpreted to signed, so that we can use existing + // saturating_narrow instructions. + for (const auto &pattern : reinterpret_patterns) { + if (expr_match(pattern.pattern, op, matches)) { + const Type signed_type = matches[0].type().with_code(halide_type_int); + Expr e = lossless_cast(signed_type, matches[0]); + if (e.defined()) { + // Can safely reinterpret to signed integer. + matches[0] = e; + value = call_overloaded_intrin(op->type, pattern.intrin, matches); + if (value) { + return; + } } + // No reinterpret patterns match the same input, so stop matching. + break; } - // No reinterpret patterns match the same input, so stop matching. - break; } - } - static const vector> cast_rewrites = { - // Some double-narrowing saturating casts can be better expressed as - // combinations of single-narrowing saturating casts. - {u8_sat(wild_i32x_), u8_sat(i16_sat(wild_i32x_))}, - {i8_sat(wild_i32x_), i8_sat(i16_sat(wild_i32x_))}, - }; - for (const auto &i : cast_rewrites) { - if (expr_match(i.first, op, matches)) { - Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes())); - value = codegen(replacement); - return; + static const vector> cast_rewrites = { + // Some double-narrowing saturating casts can be better expressed as + // combinations of single-narrowing saturating casts. + {u8_sat(wild_i32x_), u8_sat(i16_sat(wild_i32x_))}, + {i8_sat(wild_i32x_), i8_sat(i16_sat(wild_i32x_))}, + {i8_sat(wild_u32x_), i8_sat(i16_sat(wild_u32x_))}, + }; + + for (const auto &i : cast_rewrites) { + if (expr_match(i.first, op, matches)) { + Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes())); + value = codegen(replacement); + return; + } } } diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp index 4a81dfbdf926..e8f61544fe7c 100644 --- a/test/correctness/simd_op_check_x86.cpp +++ b/test/correctness/simd_op_check_x86.cpp @@ -234,6 +234,12 @@ class SimdOpCheckX86 : public SimdOpCheckTest { check(std::string("packssdw") + check_suffix, 8 * w, u8_sat(i32_1)); check(std::string("packssdw") + check_suffix, 8 * w, i8_sat(i32_1)); + // A uint without the top bit set can be reinterpreted as an int + // so that packssdw can be used. + check(std::string("packssdw") + check_suffix, 4 * w, i16_sat(u32_1 >> 1)); + check(std::string("packssdw") + check_suffix, 8 * w, i8_sat(u32_1 >> 1)); + check(std::string("packsswb") + check_suffix, 8 * w, i8_sat(u16_1 >> 1)); + // Sum-of-absolute-difference ops { const int f = 8; // reduction factor.