diff --git a/src/shims/x86/avx2.rs b/src/shims/x86/avx2.rs
index 97b9f649c1..7d8e52db73 100644
--- a/src/shims/x86/avx2.rs
+++ b/src/shims/x86/avx2.rs
@@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
 
 use super::{
     ShiftOp, horizontal_bin_op, mpsadbw, packssdw, packsswb, packusdw, packuswb, permute, pmaddbw,
-    pmulhrsw, psadbw, pshufb, psign, shift_simd_by_scalar,
+    pmaddwd, pmulhrsw, psadbw, pshufb, psign, shift_simd_by_scalar,
 };
 use crate::*;
 
@@ -232,33 +232,7 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                 let [left, right] =
                     this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
 
-                let (left, left_len) = this.project_to_simd(left)?;
-                let (right, right_len) = this.project_to_simd(right)?;
-                let (dest, dest_len) = this.project_to_simd(dest)?;
-
-                assert_eq!(left_len, right_len);
-                assert_eq!(dest_len.strict_mul(2), left_len);
-
-                for i in 0..dest_len {
-                    let j1 = i.strict_mul(2);
-                    let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_i16()?;
-                    let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i16()?;
-
-                    let j2 = j1.strict_add(1);
-                    let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_i16()?;
-                    let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i16()?;
-
-                    let dest = this.project_index(&dest, i)?;
-
-                    // Multiplications are i16*i16->i32, which will not overflow.
-                    let mul1 = i32::from(left1).strict_mul(right1.into());
-                    let mul2 = i32::from(left2).strict_mul(right2.into());
-                    // However, this addition can overflow in the most extreme case
-                    // (-0x8000)*(-0x8000)+(-0x8000)*(-0x8000) = 0x80000000
-                    let res = mul1.wrapping_add(mul2);
-
-                    this.write_scalar(Scalar::from_i32(res), &dest)?;
-                }
+                pmaddwd(this, left, right, dest)?;
             }
             _ => return interp_ok(EmulateItemResult::NotSupported),
         }
diff --git a/src/shims/x86/avx512.rs b/src/shims/x86/avx512.rs
index 0466ba1bd6..9231fc4469 100644
--- a/src/shims/x86/avx512.rs
+++ b/src/shims/x86/avx512.rs
@@ -3,7 +3,7 @@ use rustc_middle::ty::Ty;
 use rustc_span::Symbol;
 use rustc_target::callconv::FnAbi;
 
-use super::{permute, pmaddbw, psadbw, pshufb};
+use super::{permute, pmaddbw, pmaddwd, psadbw, pshufb};
 use crate::*;
 
 impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -88,6 +88,15 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
 
                 psadbw(this, left, right, dest)?
             }
+            // Used to implement the _mm512_madd_epi16 function.
+            "pmaddw.d.512" => {
+                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
+
+                let [left, right] =
+                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
+
+                pmaddwd(this, left, right, dest)?;
+            }
             // Used to implement the _mm512_maddubs_epi16 function.
             "pmaddubs.w.512" => {
                 let [left, right] =
diff --git a/src/shims/x86/mod.rs b/src/shims/x86/mod.rs
index a5164cc87a..dc0d8d48ac 100644
--- a/src/shims/x86/mod.rs
+++ b/src/shims/x86/mod.rs
@@ -964,6 +964,52 @@ fn psadbw<'tcx>(
     interp_ok(())
 }
 
+/// Multiply packed signed 16-bit integers in `left` and `right`, producing intermediate signed 32-bit integers.
+/// Horizontally add adjacent pairs of intermediate 32-bit integers, and pack the results in `dest`.
+///
+/// The sum of each pair wraps around on `i32` overflow instead of saturating.
+/// `left` and `right` are asserted to have the same number of lanes, and `dest` exactly half as many.
+/// This helper is shared by the SSE2, AVX2 and AVX-512 shims (e.g. for `_mm512_madd_epi16`).
+fn pmaddwd<'tcx>(
+    ecx: &mut crate::MiriInterpCx<'tcx>,
+    left: &OpTy<'tcx>,
+    right: &OpTy<'tcx>,
+    dest: &MPlaceTy<'tcx>,
+) -> InterpResult<'tcx, ()> {
+    let (left, left_len) = ecx.project_to_simd(left)?;
+    let (right, right_len) = ecx.project_to_simd(right)?;
+    let (dest, dest_len) = ecx.project_to_simd(dest)?;
+
+    // fn pmaddwd(a: i16x8, b: i16x8) -> i32x4;
+    // fn pmaddwd(a: i16x16, b: i16x16) -> i32x8;
+    // fn vpmaddwd(a: i16x32, b: i16x32) -> i32x16;
+    assert_eq!(left_len, right_len);
+    assert_eq!(dest_len.strict_mul(2), left_len);
+
+    for i in 0..dest_len {
+        let j1 = i.strict_mul(2);
+        let left1 = ecx.read_scalar(&ecx.project_index(&left, j1)?)?.to_i16()?;
+        let right1 = ecx.read_scalar(&ecx.project_index(&right, j1)?)?.to_i16()?;
+
+        let j2 = j1.strict_add(1);
+        let left2 = ecx.read_scalar(&ecx.project_index(&left, j2)?)?.to_i16()?;
+        let right2 = ecx.read_scalar(&ecx.project_index(&right, j2)?)?.to_i16()?;
+
+        let dest = ecx.project_index(&dest, i)?;
+
+        // Multiplications are i16*i16->i32, which will not overflow.
+        let mul1 = i32::from(left1).strict_mul(right1.into());
+        let mul2 = i32::from(left2).strict_mul(right2.into());
+        // However, this addition can overflow in the most extreme case
+        // (-0x8000)*(-0x8000)+(-0x8000)*(-0x8000) = 0x80000000
+        let res = mul1.wrapping_add(mul2);
+
+        ecx.write_scalar(Scalar::from_i32(res), &dest)?;
+    }
+
+    interp_ok(())
+}
+
 /// Multiplies packed 8-bit unsigned integers from `left` and packed
 /// signed 8-bit integers from `right` into 16-bit signed integers. Then,
 /// the saturating sum of the products with indices `2*i` and `2*i+1`
diff --git a/src/shims/x86/sse2.rs b/src/shims/x86/sse2.rs
index 3fbab9ba78..f712814a5e 100644
--- a/src/shims/x86/sse2.rs
+++ b/src/shims/x86/sse2.rs
@@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
 
 use super::{
     FloatBinOp, ShiftOp, bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int,
-    packssdw, packsswb, packuswb, psadbw, shift_simd_by_scalar,
+    packssdw, packsswb, packuswb, pmaddwd, psadbw, shift_simd_by_scalar,
 };
 use crate::*;
 
@@ -286,33 +286,7 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                 let [left, right] =
                     this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
 
-                let (left, left_len) = this.project_to_simd(left)?;
-                let (right, right_len) = this.project_to_simd(right)?;
-                let (dest, dest_len) = this.project_to_simd(dest)?;
-
-                assert_eq!(left_len, right_len);
-                assert_eq!(dest_len.strict_mul(2), left_len);
-
-                for i in 0..dest_len {
-                    let j1 = i.strict_mul(2);
-                    let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_i16()?;
-                    let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i16()?;
-
-                    let j2 = j1.strict_add(1);
-                    let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_i16()?;
-                    let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i16()?;
-
-                    let dest = this.project_index(&dest, i)?;
-
-                    // Multiplications are i16*i16->i32, which will not overflow.
-                    let mul1 = i32::from(left1).strict_mul(right1.into());
-                    let mul2 = i32::from(left2).strict_mul(right2.into());
-                    // However, this addition can overflow in the most extreme case
-                    // (-0x8000)*(-0x8000)+(-0x8000)*(-0x8000) = 0x80000000
-                    let res = mul1.wrapping_add(mul2);
-
-                    this.write_scalar(Scalar::from_i32(res), &dest)?;
-                }
+                pmaddwd(this, left, right, dest)?;
             }
             _ => return interp_ok(EmulateItemResult::NotSupported),
         }
diff --git a/tests/pass/shims/x86/intrinsics-x86-avx512.rs b/tests/pass/shims/x86/intrinsics-x86-avx512.rs
index 42acb6c3fb..7cc554ef5a 100644
--- a/tests/pass/shims/x86/intrinsics-x86-avx512.rs
+++ b/tests/pass/shims/x86/intrinsics-x86-avx512.rs
@@ -100,6 +100,77 @@ unsafe fn test_avx512() {
     }
     test_mm512_maddubs_epi16();
 
+    #[target_feature(enable = "avx512bw")]
+    unsafe fn test_mm512_madd_epi16() {
+        // Input pairs
+        //
+        // - `i16::MIN * i16::MIN + i16::MIN * i16::MIN`: the 32-bit addition overflows
+        // - `i16::MAX * i16::MAX + i16::MAX * i16::MAX`: check that widening happens before
+        //   arithmetic
+        // - `i16::MIN * i16::MAX + i16::MAX * i16::MIN`: check that large negative values are
+        //   handled correctly
+        // - `3 * 4 + 1 * 2`: A sanity check, the result should be 14.
+
+        #[rustfmt::skip]
+        let a = _mm512_set_epi16(
+            i16::MIN, i16::MIN,
+            i16::MAX, i16::MAX,
+            i16::MIN, i16::MAX,
+            3, 1,
+
+            i16::MIN, i16::MIN,
+            i16::MAX, i16::MAX,
+            i16::MIN, i16::MAX,
+            3, 1,
+
+            i16::MIN, i16::MIN,
+            i16::MAX, i16::MAX,
+            i16::MIN, i16::MAX,
+            3, 1,
+
+            i16::MIN, i16::MIN,
+            i16::MAX, i16::MAX,
+            i16::MIN, i16::MAX,
+            3, 1,
+        );
+
+        #[rustfmt::skip]
+        let b = _mm512_set_epi16(
+            i16::MIN, i16::MIN,
+            i16::MAX, i16::MAX,
+            i16::MAX, i16::MIN,
+            4, 2,
+
+            i16::MIN, i16::MIN,
+            i16::MAX, i16::MAX,
+            i16::MAX, i16::MIN,
+            4, 2,
+
+            i16::MIN, i16::MIN,
+            i16::MAX, i16::MAX,
+            i16::MAX, i16::MIN,
+            4, 2,
+
+            i16::MIN, i16::MIN,
+            i16::MAX, i16::MAX,
+            i16::MAX, i16::MIN,
+            4, 2,
+        );
+
+        let r = _mm512_madd_epi16(a, b);
+
+        #[rustfmt::skip]
+        let e = _mm512_set_epi32(
+            i32::MIN, 2_147_352_578, -2_147_418_112, 14,
+            i32::MIN, 2_147_352_578, -2_147_418_112, 14,
+            i32::MIN, 2_147_352_578, -2_147_418_112, 14,
+            i32::MIN, 2_147_352_578, -2_147_418_112, 14,
+        );
+
+        assert_eq_m512i(r, e);
+    }
+    test_mm512_madd_epi16();
+
     #[target_feature(enable = "avx512f")]
     unsafe fn test_mm512_permutexvar_epi32() {
         let a = _mm512_set_epi32(