diff --git a/benches/bench_f32.rs b/benches/bench_f32.rs index dc86827..70d4c8b 100644 --- a/benches/bench_f32.rs +++ b/benches/bench_f32.rs @@ -27,7 +27,7 @@ fn minmax_f32_random_array_long(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_random_long_f32", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -68,7 +68,7 @@ fn minmax_f32_random_array_short(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_random_short_f32", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -109,7 +109,7 @@ fn minmax_f32_worst_case_array_long(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_worst_long_f32", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -150,7 +150,7 @@ fn minmax_f32_worst_case_array_short(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_worst_short_f32", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); diff --git a/benches/bench_f64.rs b/benches/bench_f64.rs index a876d8e..820f515 100644 --- a/benches/bench_f64.rs +++ b/benches/bench_f64.rs @@ -25,7 +25,7 @@ fn minmax_f64_random_array_long(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_random_long_f64", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -54,7 +54,7 @@ fn minmax_f64_random_array_short(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_random_short_f64", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -83,7 +83,7 @@ fn minmax_f64_worst_case_array_long(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_worst_long_f64", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -112,7 +112,7 @@ fn minmax_f64_worst_case_array_short(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_worst_short_f64", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); diff --git a/src/lib.rs b/src/lib.rs index 1ccdc5d..f643881 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -117,9 +117,6 @@ macro_rules! impl_argminmax { return unsafe { AVX512::argminmax(self) } } else if is_x86_feature_detected!("avx2") { return unsafe { AVX2::argminmax(self) } - } else if is_x86_feature_detected!("avx") & (<$t>::NB_BITS >= 32) & (<$t>::IS_FLOAT == true) { - // f32 and f64 do not require avx2 - return unsafe { AVX2::argminmax(self) } // SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers // // } else if is_x86_feature_detected!("sse4.2") & (<$t>::NB_BITS == 64) & (<$t>::IS_FLOAT == false) { // // SSE4.2 is needed for comparing 64-bit integers diff --git a/src/simd/generic.rs b/src/simd/generic.rs index 01eaeba..a16718b 100644 --- a/src/simd/generic.rs +++ b/src/simd/generic.rs @@ -8,14 +8,16 @@ use crate::scalar::{ScalarArgMinMax, SCALAR}; // TODO: other potential generic SIMDIndexDtype: Copy #[allow(clippy::missing_safety_doc)] // TODO: add safety docs? pub trait SIMD< - ScalarDType: Copy + PartialOrd + AsPrimitive, - SIMDVecDtype: Copy, + ValueDType: Copy + PartialOrd, + SIMDValueDtype: Copy, + IndexDtype: Copy + PartialOrd + AsPrimitive, + SIMDIndexDtype: Copy, SIMDMaskDtype: Copy, const LANE_SIZE: usize, > { - const INITIAL_INDEX: SIMDVecDtype; - const MAX_INDEX: usize; // Integers > this value **cannot** be accurately represented in SIMDVecDtype + const INITIAL_INDEX: SIMDIndexDtype; + const MAX_INDEX: usize; // Integers > this value **cannot** be accurately represented in SIMDIndexDtype #[inline(always)] fn _find_largest_lower_multiple_of_lane_size(n: usize) -> usize { @@ -24,22 +26,34 @@ pub trait SIMD< // ------------------------------------ SIMD HELPERS -------------------------------------- - unsafe fn _reg_to_arr(reg: SIMDVecDtype) -> [ScalarDType; LANE_SIZE]; + unsafe fn _reg_to_arr_values(reg: SIMDValueDtype) -> [ValueDType; LANE_SIZE]; - unsafe fn _mm_loadu(data: *const ScalarDType) -> SIMDVecDtype; + unsafe fn _reg_to_arr_indices(reg: SIMDIndexDtype) -> [IndexDtype; LANE_SIZE]; - unsafe fn _mm_set1(a: usize) -> SIMDVecDtype; + unsafe fn _mm_loadu(data: *const ValueDType) -> SIMDValueDtype; - unsafe fn _mm_add(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDVecDtype; + unsafe fn _mm_set1(a: usize) -> SIMDIndexDtype; - unsafe fn _mm_cmpgt(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDMaskDtype; + unsafe fn _mm_add(a: SIMDIndexDtype, b: SIMDIndexDtype) -> SIMDIndexDtype; - unsafe fn _mm_cmplt(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDMaskDtype; + unsafe fn _mm_cmpgt(a: SIMDValueDtype, b: SIMDValueDtype) -> SIMDMaskDtype; - unsafe fn _mm_blendv(a: SIMDVecDtype, b: SIMDVecDtype, mask: SIMDMaskDtype) -> SIMDVecDtype; + unsafe fn _mm_cmplt(a: SIMDValueDtype, b: SIMDValueDtype) -> SIMDMaskDtype; + + unsafe fn _mm_blendv_values( + a: SIMDValueDtype, + b: SIMDValueDtype, + mask: SIMDMaskDtype, + ) -> SIMDValueDtype; + + unsafe fn _mm_blendv_indices( + a: SIMDIndexDtype, + b: SIMDIndexDtype, + mask: SIMDMaskDtype, + ) -> SIMDIndexDtype; #[inline(always)] - unsafe fn _horiz_min(index: SIMDVecDtype, value: SIMDVecDtype) -> (usize, ScalarDType) { + unsafe fn _horiz_min(index: SIMDIndexDtype, value: SIMDValueDtype) -> (usize, ValueDType) { // This becomes the bottleneck when using 8-bit data types, as for every 2**7 // or 2**8 elements, the SIMD inner loop is executed (& thus also terminated) // to avoid overflow. @@ -48,14 +62,14 @@ pub trait SIMD< // see: https://stackoverflow.com/a/9798369 // Note: this is not a bottleneck for 16-bit data types, as the termination of // the SIMD inner loop is 2**8 times less frequent. - let index_arr = Self::_reg_to_arr(index); - let value_arr = Self::_reg_to_arr(value); + let index_arr = Self::_reg_to_arr_indices(index); + let value_arr = Self::_reg_to_arr_values(value); let (min_index, min_value) = min_index_value(&index_arr, &value_arr); (min_index.as_(), min_value) } #[inline(always)] - unsafe fn _horiz_max(index: SIMDVecDtype, value: SIMDVecDtype) -> (usize, ScalarDType) { + unsafe fn _horiz_max(index: SIMDIndexDtype, value: SIMDValueDtype) -> (usize, ValueDType) { // This becomes the bottleneck when using 8-bit data types, as for every 2**7 // or 2**8 elements, the SIMD inner loop is executed (& thus also terminated) // to avoid overflow. @@ -64,14 +78,14 @@ pub trait SIMD< // see: https://stackoverflow.com/a/9798369 // Note: this is not a bottleneck for 16-bit data types, as the termination of // the SIMD inner loop is 2**8 times less frequent. - let index_arr = Self::_reg_to_arr(index); - let value_arr = Self::_reg_to_arr(value); + let index_arr = Self::_reg_to_arr_indices(index); + let value_arr = Self::_reg_to_arr_values(value); let (max_index, max_value) = max_index_value(&index_arr, &value_arr); (max_index.as_(), max_value) } #[inline(always)] - unsafe fn _mm_prefetch(data: *const ScalarDType) { + unsafe fn _mm_prefetch(data: *const ValueDType) { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { #[cfg(target_arch = "x86")] @@ -91,20 +105,20 @@ pub trait SIMD< // ------------------------------------ ARGMINMAX -------------------------------------- - unsafe fn argminmax(data: ArrayView1) -> (usize, usize); + unsafe fn argminmax(data: ArrayView1) -> (usize, usize); #[inline(always)] - unsafe fn _argminmax(data: ArrayView1) -> (usize, usize) + unsafe fn _argminmax(data: ArrayView1) -> (usize, usize) where - SCALAR: ScalarArgMinMax, + SCALAR: ScalarArgMinMax, { argminmax_generic(data, LANE_SIZE, Self::_overflow_safe_core_argminmax) } #[inline(always)] unsafe fn _overflow_safe_core_argminmax( - arr: ArrayView1, - ) -> (usize, ScalarDType, usize, ScalarDType) { + arr: ArrayView1, + ) -> (usize, ValueDType, usize, ValueDType) { // 0. Get the max value of the data type - which needs to be divided by LANE_SIZE let dtype_max = Self::_find_largest_lower_multiple_of_lane_size(Self::MAX_INDEX); @@ -171,11 +185,11 @@ pub trait SIMD< // TODO: can be cleaner (perhaps?) #[inline(always)] unsafe fn _get_min_max_index_value( - index_low: SIMDVecDtype, - values_low: SIMDVecDtype, - index_high: SIMDVecDtype, - values_high: SIMDVecDtype, - ) -> (usize, ScalarDType, usize, ScalarDType) { + index_low: SIMDIndexDtype, + values_low: SIMDValueDtype, + index_high: SIMDIndexDtype, + values_high: SIMDValueDtype, + ) -> (usize, ValueDType, usize, ValueDType) { let (min_index, min_value) = Self::_horiz_min(index_low, values_low); let (max_index, max_value) = Self::_horiz_max(index_high, values_high); (min_index, min_value, max_index, max_value) @@ -183,8 +197,8 @@ pub trait SIMD< #[inline(always)] unsafe fn _core_argminmax( - arr: ArrayView1, - ) -> (usize, ScalarDType, usize, ScalarDType) { + arr: ArrayView1, + ) -> (usize, ValueDType, usize, ValueDType) { assert_eq!(arr.len() % LANE_SIZE, 0); // Efficient calculation of argmin and argmax together let mut new_index = Self::INITIAL_INDEX; @@ -228,12 +242,12 @@ pub trait SIMD< let gt_mask = Self::_mm_cmpgt(new_values, values_high); // Update the highest and lowest values - values_low = Self::_mm_blendv(values_low, new_values, lt_mask); - values_high = Self::_mm_blendv(values_high, new_values, gt_mask); + values_low = Self::_mm_blendv_values(values_low, new_values, lt_mask); + values_high = Self::_mm_blendv_values(values_high, new_values, gt_mask); // Update the index if the new value is lower/higher - index_low = Self::_mm_blendv(index_low, new_index, lt_mask); - index_high = Self::_mm_blendv(index_high, new_index, gt_mask); + index_low = Self::_mm_blendv_indices(index_low, new_index, lt_mask); + index_high = Self::_mm_blendv_indices(index_high, new_index, gt_mask); // 25 is a non-scientific number, but seems to work overall // => TODO: probably this should be in function of the data type @@ -247,11 +261,15 @@ pub trait SIMD< #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] macro_rules! unimplement_simd { ($scalar_type:ty, $reg:ty, $simd_type:ident) => { - impl SIMD<$scalar_type, $reg, $reg, 0> for $simd_type { + impl SIMD<$scalar_type, $reg, $scalar_type, $reg, $reg, 0> for $simd_type { const INITIAL_INDEX: $reg = 0; const MAX_INDEX: usize = 0; - unsafe fn _reg_to_arr(_reg: $reg) -> [$scalar_type; 0] { + unsafe fn _reg_to_arr_values(_reg: $reg) -> [$scalar_type; 0] { + unimplemented!() + } + + unsafe fn _reg_to_arr_indices(_reg: $reg) -> [$scalar_type; 0] { unimplemented!() } @@ -275,7 +293,11 @@ macro_rules! unimplement_simd { unimplemented!() } - unsafe fn _mm_blendv(_a: $reg, _b: $reg, _mask: $reg) -> $reg { + unsafe fn _mm_blendv_values(_a: $reg, _b: $reg, _mask: $reg) -> $reg { + unimplemented!() + } + + unsafe fn _mm_blendv_indices(_a: $reg, _b: $reg, _mask: $reg) -> $reg { unimplemented!() } diff --git a/src/simd/simd_f16.rs b/src/simd/simd_f16.rs index 9c3399c..ae30f9a 100644 --- a/src/simd/simd_f16.rs +++ b/src/simd/simd_f16.rs @@ -57,7 +57,7 @@ mod avx2 { std::mem::transmute::<__m256i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -67,7 +67,13 @@ mod avx2 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m256i) -> [f16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m256i) -> [f16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m256i) -> [i16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -98,7 +104,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -287,13 +298,19 @@ mod sse { std::mem::transmute::<__m128i, [i16; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m128i) -> [f16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m128i) -> [f16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m128i) -> [i16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -324,7 +341,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -492,7 +514,7 @@ mod avx512 { std::mem::transmute::<__m512i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -503,7 +525,13 @@ mod avx512 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m512i) -> [f16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m512i) -> [f16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m512i) -> [i16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -534,7 +562,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u32) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u32) -> __m512i { + _mm512_mask_blend_epi16(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u32) -> __m512i { _mm512_mask_blend_epi16(mask, a, b) } @@ -727,13 +760,19 @@ mod neon { std::mem::transmute::(reg) } - impl SIMD for NEON { - const INITIAL_INDEX: int16x8_t = - unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; - const MAX_INDEX: usize = i16::MAX as usize; + impl SIMD for NEON { + const INITIAL_INDEX: uint16x8_t = + unsafe { std::mem::transmute([0u16, 1u16, 2u16, 3u16, 4u16, 5u16, 6u16, 7u16]) }; + const MAX_INDEX: usize = u16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: int16x8_t) -> [f16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: int16x8_t) -> [f16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: uint16x8_t) -> [u16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -746,13 +785,13 @@ mod neon { } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int16x8_t { - vdupq_n_s16(a as i16) + unsafe fn _mm_set1(a: usize) -> uint16x8_t { + vdupq_n_u16(a as u16) } #[inline(always)] - unsafe fn _mm_add(a: int16x8_t, b: int16x8_t) -> int16x8_t { - vaddq_s16(a, b) + unsafe fn _mm_add(a: uint16x8_t, b: uint16x8_t) -> uint16x8_t { + vaddq_u16(a, b) } #[inline(always)] @@ -766,10 +805,15 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: int16x8_t, b: int16x8_t, mask: uint16x8_t) -> int16x8_t { + unsafe fn _mm_blendv_values(a: int16x8_t, b: int16x8_t, mask: uint16x8_t) -> int16x8_t { vbslq_s16(mask, b, a) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint16x8_t, b: uint16x8_t, mask: uint16x8_t) -> uint16x8_t { + vbslq_u16(mask, b, a) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "neon")] @@ -778,7 +822,7 @@ mod neon { } #[inline(always)] - unsafe fn _horiz_min(index: int16x8_t, value: int16x8_t) -> (usize, f16) { + unsafe fn _horiz_min(index: uint16x8_t, value: int16x8_t) -> (usize, f16) { // 0. Find the minimum value let mut vmin: int16x8_t = value; vmin = vminq_s16(vmin, vextq_s16(vmin, vmin, 4)); @@ -790,23 +834,23 @@ mod neon { // 1. Create a mask with the index of the minimum value let mask = vceqq_s16(value, vmin); // 2. Blend the mask with the index - let search_index = vbslq_s16( + let search_index = vbslq_u16( mask, index, // if mask is 1, use index - vdupq_n_s16(i16::MAX), // if mask is 0, use i16::MAX + vdupq_n_u16(u16::MAX), // if mask is 0, use u16::MAX ); // 3. Find the minimum index - let mut imin: int16x8_t = search_index; - imin = vminq_s16(imin, vextq_s16(imin, imin, 4)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 2)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 1)); - let min_index: usize = vgetq_lane_s16(imin, 0) as usize; + let mut imin: uint16x8_t = search_index; + imin = vminq_u16(imin, vextq_u16(imin, imin, 4)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 2)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 1)); + let min_index: usize = vgetq_lane_u16(imin, 0) as usize; (min_index, _ord_i16_to_f16(min_value)) } #[inline(always)] - unsafe fn _horiz_max(index: int16x8_t, value: int16x8_t) -> (usize, f16) { + unsafe fn _horiz_max(index: uint16x8_t, value: int16x8_t) -> (usize, f16) { // 0. Find the maximum value let mut vmax: int16x8_t = value; vmax = vmaxq_s16(vmax, vextq_s16(vmax, vmax, 4)); @@ -818,17 +862,17 @@ mod neon { // 1. Create a mask with the index of the maximum value let mask = vceqq_s16(value, vmax); // 2. Blend the mask with the index - let search_index = vbslq_s16( + let search_index = vbslq_u16( mask, index, // if mask is 1, use index - vdupq_n_s16(i16::MAX), // if mask is 0, use i16::MAX + vdupq_n_u16(u16::MAX), // if mask is 0, use u16::MAX ); // 3. Find the maximum index - let mut imin: int16x8_t = search_index; - imin = vminq_s16(imin, vextq_s16(imin, imin, 4)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 2)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 1)); - let max_index: usize = vgetq_lane_s16(imin, 0) as usize; + let mut imin: uint16x8_t = search_index; + imin = vminq_u16(imin, vextq_u16(imin, imin, 4)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 2)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 1)); + let max_index: usize = vgetq_lane_u16(imin, 0) as usize; (max_index, _ord_i16_to_f16(max_value)) } diff --git a/src/simd/simd_f32.rs b/src/simd/simd_f32.rs index 35aa7c3..e7546d4 100644 --- a/src/simd/simd_f32.rs +++ b/src/simd/simd_f32.rs @@ -19,33 +19,34 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_32; - impl SIMD for AVX2 { - const INITIAL_INDEX: __m256 = unsafe { - std::mem::transmute([ - 0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, - ]) - }; - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; + impl SIMD for AVX2 { + const INITIAL_INDEX: __m256i = + unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32]) }; + const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256) -> [f32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256) -> [f32; LANE_SIZE] { std::mem::transmute::<__m256, [f32; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m256i, [i32; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f32) -> __m256 { _mm256_loadu_ps(data as *const f32) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256 { - _mm256_set1_ps(a as f32) + unsafe fn _mm_set1(a: usize) -> __m256i { + _mm256_set1_epi32(a as i32) } #[inline(always)] - unsafe fn _mm_add(a: __m256, b: __m256) -> __m256 { - _mm256_add_ps(a, b) + unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i { + _mm256_add_epi32(a, b) } #[inline(always)] @@ -59,13 +60,18 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256, b: __m256, mask: __m256) -> __m256 { + unsafe fn _mm_blendv_values(a: __m256, b: __m256, mask: __m256) -> __m256 { _mm256_blendv_ps(a, b, mask) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256) -> __m256i { + _mm256_blendv_epi8(a, b, _mm256_castps_si256(mask)) + } + // ------------------------------------ ARGMINMAX -------------------------------------- - #[target_feature(enable = "avx")] + #[target_feature(enable = "avx2")] unsafe fn argminmax(data: ArrayView1) -> (usize, usize) { Self::_argminmax(data) } @@ -89,7 +95,7 @@ mod avx2 { #[test] fn test_both_versions_return_the_same_results() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -104,7 +110,7 @@ mod avx2 { #[test] fn test_first_index_is_returned_when_identical_values_found() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -131,7 +137,7 @@ mod avx2 { #[test] fn test_no_overflow() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -146,7 +152,7 @@ mod avx2 { #[test] fn test_many_random_runs() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -171,30 +177,33 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_32; - impl SIMD for SSE { - const INITIAL_INDEX: __m128 = - unsafe { std::mem::transmute([0.0f32, 1.0f32, 2.0f32, 3.0f32]) }; - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; + impl SIMD for SSE { + const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; + const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128) -> [f32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128) -> [f32; LANE_SIZE] { std::mem::transmute::<__m128, [f32; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m128i, [i32; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f32) -> __m128 { _mm_loadu_ps(data as *const f32) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128 { - _mm_set1_ps(a as f32) + unsafe fn _mm_set1(a: usize) -> __m128i { + _mm_set1_epi32(a as i32) } #[inline(always)] - unsafe fn _mm_add(a: __m128, b: __m128) -> __m128 { - _mm_add_ps(a, b) + unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i { + _mm_add_epi32(a, b) } #[inline(always)] @@ -208,10 +217,15 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128, b: __m128, mask: __m128) -> __m128 { + unsafe fn _mm_blendv_values(a: __m128, b: __m128, mask: __m128) -> __m128 { _mm_blendv_ps(a, b, mask) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128) -> __m128i { + _mm_blendv_epi8(a, b, _mm_castps_si128(mask)) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "sse4.1")] @@ -303,34 +317,38 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_32; - impl SIMD for AVX512 { - const INITIAL_INDEX: __m512 = unsafe { + impl SIMD for AVX512 { + const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ - 0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32, 9.0f32, - 10.0f32, 11.0f32, 12.0f32, 13.0f32, 14.0f32, 15.0f32, + 0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 8i32, 9i32, 10i32, 11i32, 12i32, + 13i32, 14i32, 15i32, ]) }; - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; + const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512) -> [f32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512) -> [f32; LANE_SIZE] { std::mem::transmute::<__m512, [f32; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m512i, [i32; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f32) -> __m512 { _mm512_loadu_ps(data as *const f32) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512 { - _mm512_set1_ps(a as f32) + unsafe fn _mm_set1(a: usize) -> __m512i { + _mm512_set1_epi32(a as i32) } #[inline(always)] - unsafe fn _mm_add(a: __m512, b: __m512) -> __m512 { - _mm512_add_ps(a, b) + unsafe fn _mm_add(a: __m512i, b: __m512i) -> __m512i { + _mm512_add_epi32(a, b) } #[inline(always)] @@ -354,7 +372,7 @@ mod avx512 { // { _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ) } #[inline(always)] - unsafe fn _mm_blendv(a: __m512, b: __m512, mask: u16) -> __m512 { + unsafe fn _mm_blendv_values(a: __m512, b: __m512, mask: u16) -> __m512 { _mm512_mask_blend_ps(mask, a, b) } // unimplemented!("AVX512 blendv instructions for ps require a u16 mask.") @@ -365,6 +383,11 @@ mod avx512 { // _mm512_mask_mov_ps(a, _mm512_castps_si512(mask), b) // } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u16) -> __m512i { + _mm512_mask_blend_epi32(mask, a, b) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "avx512f")] @@ -473,16 +496,19 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_32; - impl SIMD for NEON { - const INITIAL_INDEX: float32x4_t = - unsafe { std::mem::transmute([0.0f32, 1.0f32, 2.0f32, 3.0f32]) }; + impl SIMD for NEON { + const INITIAL_INDEX: uint32x4_t = unsafe { std::mem::transmute([0u32, 1u32, 2u32, 3u32]) }; + const MAX_INDEX: usize = u32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: float32x4_t) -> [f32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: float32x4_t) -> [f32; LANE_SIZE] { std::mem::transmute::(reg) } - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint32x4_t) -> [u32; LANE_SIZE] { + std::mem::transmute::(reg) + } #[inline(always)] unsafe fn _mm_loadu(data: *const f32) -> float32x4_t { @@ -490,13 +516,13 @@ mod neon { } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> float32x4_t { - vdupq_n_f32(a as f32) + unsafe fn _mm_set1(a: usize) -> uint32x4_t { + vdupq_n_u32(a as u32) } #[inline(always)] - unsafe fn _mm_add(a: float32x4_t, b: float32x4_t) -> float32x4_t { - vaddq_f32(a, b) + unsafe fn _mm_add(a: uint32x4_t, b: uint32x4_t) -> uint32x4_t { + vaddq_u32(a, b) } #[inline(always)] @@ -510,10 +536,19 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: float32x4_t, b: float32x4_t, mask: uint32x4_t) -> float32x4_t { + unsafe fn _mm_blendv_values( + a: float32x4_t, + b: float32x4_t, + mask: uint32x4_t, + ) -> float32x4_t { vbslq_f32(mask, b, a) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint32x4_t, b: uint32x4_t, mask: uint32x4_t) -> uint32x4_t { + vbslq_u32(mask, b, a) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "neon")] diff --git a/src/simd/simd_f64.rs b/src/simd/simd_f64.rs index e252cde..ce14232 100644 --- a/src/simd/simd_f64.rs +++ b/src/simd/simd_f64.rs @@ -16,33 +16,33 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_64; - impl SIMD for AVX2 { - const INITIAL_INDEX: __m256d = - unsafe { std::mem::transmute([0.0f64, 1.0f64, 2.0f64, 3.0f64]) }; - // https://stackoverflow.com/a/3793950 - #[cfg(target_arch = "x86_64")] - const MAX_INDEX: usize = 1 << f64::MANTISSA_DIGITS; - #[cfg(target_arch = "x86")] // https://stackoverflow.com/a/29592369 - const MAX_INDEX: usize = u32::MAX as usize; + impl SIMD for AVX2 { + const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64]) }; + const MAX_INDEX: usize = i64::MAX as usize; // TODO overflow on x86? #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256d) -> [f64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256d) -> [f64; LANE_SIZE] { std::mem::transmute::<__m256d, [f64; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m256i, [i64; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f64) -> __m256d { _mm256_loadu_pd(data as *const f64) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256d { - _mm256_set1_pd(a as f64) + unsafe fn _mm_set1(a: usize) -> __m256i { + _mm256_set1_epi64x(a as i64) } #[inline(always)] - unsafe fn _mm_add(a: __m256d, b: __m256d) -> __m256d { - _mm256_add_pd(a, b) + unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i { + _mm256_add_epi64(a, b) } #[inline(always)] @@ -56,13 +56,18 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256d, b: __m256d, mask: __m256d) -> __m256d { + unsafe fn _mm_blendv_values(a: __m256d, b: __m256d, mask: __m256d) -> __m256d { _mm256_blendv_pd(a, b, mask) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256d) -> __m256i { + _mm256_blendv_epi8(a, b, _mm256_castpd_si256(mask)) + } + // ------------------------------------ ARGMINMAX -------------------------------------- - #[target_feature(enable = "avx")] + #[target_feature(enable = "avx2")] unsafe fn argminmax(data: ArrayView1) -> (usize, usize) { Self::_argminmax(data) } @@ -86,7 +91,7 @@ mod avx2 { #[test] fn test_both_versions_return_the_same_results() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -101,7 +106,7 @@ mod avx2 { #[test] fn test_first_index_is_returned_when_identical_values_found() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -128,7 +133,7 @@ mod avx2 { #[test] fn test_many_random_runs() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -153,32 +158,33 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_64; - impl SIMD for SSE { - const INITIAL_INDEX: __m128d = unsafe { std::mem::transmute([0.0f64, 1.0f64]) }; - // https://stackoverflow.com/a/3793950 - #[cfg(target_arch = "x86_64")] - const MAX_INDEX: usize = 1 << f64::MANTISSA_DIGITS; - #[cfg(target_arch = "x86")] // https://stackoverflow.com/a/29592369 - const MAX_INDEX: usize = u32::MAX as usize; + impl SIMD for SSE { + const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i64, 1i64]) }; + const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128d) -> [f64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128d) -> [f64; LANE_SIZE] { std::mem::transmute::<__m128d, [f64; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m128i, [i64; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f64) -> __m128d { _mm_loadu_pd(data as *const f64) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128d { - _mm_set1_pd(a as f64) + unsafe fn _mm_set1(a: usize) -> __m128i { + _mm_set1_epi64x(a as i64) } #[inline(always)] - unsafe fn _mm_add(a: __m128d, b: __m128d) -> __m128d { - _mm_add_pd(a, b) + unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i { + _mm_add_epi64(a, b) } #[inline(always)] @@ -192,10 +198,15 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128d, b: __m128d, mask: __m128d) -> __m128d { + unsafe fn _mm_blendv_values(a: __m128d, b: __m128d, mask: __m128d) -> __m128d { _mm_blendv_pd(a, b, mask) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128d) -> __m128i { + _mm_blendv_epi8(a, b, _mm_castpd_si128(mask)) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "sse4.1")] @@ -276,36 +287,34 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_64; - impl SIMD for AVX512 { - const INITIAL_INDEX: __m512d = unsafe { - std::mem::transmute([ - 0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, - ]) - }; - // https://stackoverflow.com/a/3793950 - #[cfg(target_arch = "x86_64")] - const MAX_INDEX: usize = 1 << f64::MANTISSA_DIGITS; - #[cfg(target_arch = "x86")] // https://stackoverflow.com/a/29592369 - const MAX_INDEX: usize = u32::MAX as usize; + impl SIMD for AVX512 { + const INITIAL_INDEX: __m512i = + unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64]) }; + const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512d) -> [f64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512d) -> [f64; LANE_SIZE] { std::mem::transmute::<__m512d, [f64; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m512i, [i64; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f64) -> __m512d { _mm512_loadu_pd(data as *const f64) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512d { - _mm512_set1_pd(a as f64) + unsafe fn _mm_set1(a: usize) -> __m512i { + _mm512_set1_epi64(a as i64) } #[inline(always)] - unsafe fn _mm_add(a: __m512d, b: __m512d) -> __m512d { - _mm512_add_pd(a, b) + unsafe fn _mm_add(a: __m512i, b: __m512i) -> __m512i { + _mm512_add_epi64(a, b) } #[inline(always)] @@ -319,10 +328,15 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512d, b: __m512d, mask: u8) -> __m512d { + unsafe fn _mm_blendv_values(a: __m512d, b: __m512d, mask: u8) -> __m512d { _mm512_mask_blend_pd(mask, a, b) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u8) -> __m512i { + _mm512_mask_blend_epi64(mask, a, b) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "avx512f")] diff --git a/src/simd/simd_i16.rs b/src/simd/simd_i16.rs index 715d095..8471253 100644 --- a/src/simd/simd_i16.rs +++ b/src/simd/simd_i16.rs @@ -19,7 +19,7 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_16; - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -29,7 +29,12 @@ mod avx2 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256i) -> [i16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256i) -> [i16; LANE_SIZE] { + std::mem::transmute::<__m256i, [i16; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i16; LANE_SIZE] { std::mem::transmute::<__m256i, [i16; LANE_SIZE]>(reg) } @@ -59,7 +64,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -233,13 +243,18 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_16; - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128i) -> [i16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128i) -> [i16; LANE_SIZE] { + std::mem::transmute::<__m128i, [i16; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i16; LANE_SIZE] { std::mem::transmute::<__m128i, [i16; LANE_SIZE]>(reg) } @@ -269,7 +284,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -422,7 +442,7 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_16; - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -433,7 +453,12 @@ mod avx512 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512i) -> [i16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512i) -> [i16; LANE_SIZE] { + std::mem::transmute::<__m512i, [i16; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i16; LANE_SIZE] { std::mem::transmute::<__m512i, [i16; LANE_SIZE]>(reg) } @@ -463,7 +488,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u32) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u32) -> __m512i { + _mm512_mask_blend_epi16(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u32) -> __m512i { _mm512_mask_blend_epi16(mask, a, b) } @@ -641,29 +671,34 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_16; - impl SIMD for NEON { - const INITIAL_INDEX: int16x8_t = - unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; - const MAX_INDEX: usize = i16::MAX as usize; + impl SIMD for NEON { + const INITIAL_INDEX: uint16x8_t = + unsafe { std::mem::transmute([0u16, 1u16, 2u16, 3u16, 4u16, 5u16, 6u16, 7u16]) }; + const MAX_INDEX: usize = u16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: int16x8_t) -> [i16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: int16x8_t) -> [i16; LANE_SIZE] { std::mem::transmute::(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint16x8_t) -> [u16; LANE_SIZE] { + std::mem::transmute::(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const i16) -> int16x8_t { vld1q_s16(data as *const i16) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int16x8_t { - vdupq_n_s16(a as i16) + unsafe fn _mm_set1(a: usize) -> uint16x8_t { + vdupq_n_u16(a as u16) } #[inline(always)] - unsafe fn _mm_add(a: int16x8_t, b: int16x8_t) -> int16x8_t { - vaddq_s16(a, b) + unsafe fn _mm_add(a: uint16x8_t, b: uint16x8_t) -> uint16x8_t { + vaddq_u16(a, b) } #[inline(always)] @@ -677,10 +712,15 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: int16x8_t, b: int16x8_t, mask: uint16x8_t) -> int16x8_t { + unsafe fn _mm_blendv_values(a: int16x8_t, b: int16x8_t, mask: uint16x8_t) -> int16x8_t { vbslq_s16(mask, b, a) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint16x8_t, b: uint16x8_t, mask: uint16x8_t) -> uint16x8_t { + vbslq_u16(mask, b, a) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "neon")] @@ -689,7 +729,7 @@ mod neon { } #[inline(always)] - unsafe fn _horiz_min(index: int16x8_t, value: int16x8_t) -> (usize, i16) { + unsafe fn _horiz_min(index: uint16x8_t, value: int16x8_t) -> (usize, i16) { // 0. Find the minimum value let mut vmin: int16x8_t = value; vmin = vminq_s16(vmin, vextq_s16(vmin, vmin, 4)); @@ -701,23 +741,23 @@ mod neon { // 1. Create a mask with the index of the minimum value let mask = vceqq_s16(value, vmin); // 2. Blend the mask with the index - let search_index = vbslq_s16( + let search_index = vbslq_u16( mask, index, // if mask is 1, use index - vdupq_n_s16(i16::MAX), // if mask is 0, use i16::MAX + vdupq_n_u16(u16::MAX), // if mask is 0, use u16::MAX ); // 3. Find the minimum index - let mut imin: int16x8_t = search_index; - imin = vminq_s16(imin, vextq_s16(imin, imin, 4)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 2)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 1)); - let min_index: usize = vgetq_lane_s16(imin, 0) as usize; + let mut imin: uint16x8_t = search_index; + imin = vminq_u16(imin, vextq_u16(imin, imin, 4)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 2)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 1)); + let min_index: usize = vgetq_lane_u16(imin, 0) as usize; (min_index, min_value) } #[inline(always)] - unsafe fn _horiz_max(index: int16x8_t, value: int16x8_t) -> (usize, i16) { + unsafe fn _horiz_max(index: uint16x8_t, value: int16x8_t) -> (usize, i16) { // 0. Find the maximum value let mut vmax: int16x8_t = value; vmax = vmaxq_s16(vmax, vextq_s16(vmax, vmax, 4)); @@ -729,17 +769,17 @@ mod neon { // 1. Create a mask with the index of the maximum value let mask = vceqq_s16(value, vmax); // 2. Blend the mask with the index - let search_index = vbslq_s16( + let search_index = vbslq_u16( mask, index, // if mask is 1, use index - vdupq_n_s16(i16::MAX), // if mask is 0, use i16::MAX + vdupq_n_u16(u16::MAX), // if mask is 0, use u16::MAX ); // 3. Find the maximum index - let mut imin: int16x8_t = search_index; - imin = vminq_s16(imin, vextq_s16(imin, imin, 4)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 2)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 1)); - let max_index: usize = vgetq_lane_s16(imin, 0) as usize; + let mut imin: uint16x8_t = search_index; + imin = vminq_u16(imin, vextq_u16(imin, imin, 4)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 2)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 1)); + let max_index: usize = vgetq_lane_u16(imin, 0) as usize; (max_index, max_value) } diff --git a/src/simd/simd_i32.rs b/src/simd/simd_i32.rs index 54bbb96..3748bf2 100644 --- a/src/simd/simd_i32.rs +++ b/src/simd/simd_i32.rs @@ -19,13 +19,18 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_32; - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32]) }; const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256i) -> [i32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m256i, [i32; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i32; LANE_SIZE] { std::mem::transmute::<__m256i, [i32; LANE_SIZE]>(reg) } @@ -55,7 +60,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -153,12 +163,17 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_32; - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128i) -> [i32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m128i, [i32; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i32; LANE_SIZE] { std::mem::transmute::<__m128i, [i32; LANE_SIZE]>(reg) } @@ -188,7 +203,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -273,7 +293,7 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_32; - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 8i32, 9i32, 10i32, 11i32, 12i32, @@ -283,7 +303,12 @@ mod avx512 { const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512i) -> [i32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m512i, [i32; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i32; LANE_SIZE] { std::mem::transmute::<__m512i, [i32; LANE_SIZE]>(reg) } @@ -313,7 +338,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u16) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u16) -> __m512i { + _mm512_mask_blend_epi32(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u16) -> __m512i { _mm512_mask_blend_epi32(mask, a, b) } @@ -411,28 +441,33 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_32; - impl SIMD for NEON { - const INITIAL_INDEX: int32x4_t = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; - const MAX_INDEX: usize = i32::MAX as usize; + impl SIMD for NEON { + const INITIAL_INDEX: uint32x4_t = unsafe { std::mem::transmute([0u32, 1u32, 2u32, 3u32]) }; + const MAX_INDEX: usize = u32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: int32x4_t) -> [i32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: int32x4_t) -> [i32; LANE_SIZE] { std::mem::transmute::(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint32x4_t) -> [u32; LANE_SIZE] { + std::mem::transmute::(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const i32) -> int32x4_t { vld1q_s32(data) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int32x4_t { - vdupq_n_s32(a as i32) + unsafe fn _mm_set1(a: usize) -> uint32x4_t { + vdupq_n_u32(a as u32) } #[inline(always)] - unsafe fn _mm_add(a: int32x4_t, b: int32x4_t) -> int32x4_t { - vaddq_s32(a, b) + unsafe fn _mm_add(a: uint32x4_t, b: uint32x4_t) -> uint32x4_t { + vaddq_u32(a, b) } #[inline(always)] @@ -446,10 +481,15 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: int32x4_t, b: int32x4_t, mask: uint32x4_t) -> int32x4_t { + unsafe fn _mm_blendv_values(a: int32x4_t, b: int32x4_t, mask: uint32x4_t) -> int32x4_t { vbslq_s32(mask, b, a) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint32x4_t, b: uint32x4_t, mask: uint32x4_t) -> uint32x4_t { + vbslq_u32(mask, b, a) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "neon")] diff --git a/src/simd/simd_i64.rs b/src/simd/simd_i64.rs index 277ecff..d1846c4 100644 --- a/src/simd/simd_i64.rs +++ b/src/simd/simd_i64.rs @@ -16,12 +16,17 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_64; - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256i) -> [i64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m256i, [i64; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i64; LANE_SIZE] { std::mem::transmute::<__m256i, [i64; LANE_SIZE]>(reg) } @@ -51,7 +56,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -149,12 +159,17 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_64; - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i64, 1i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128i) -> [i64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m128i, [i64; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i64; LANE_SIZE] { std::mem::transmute::<__m128i, [i64; LANE_SIZE]>(reg) } @@ -184,7 +199,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -269,13 +289,18 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_64; - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512i) -> [i64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m512i, [i64; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i64; LANE_SIZE] { std::mem::transmute::<__m512i, [i64; LANE_SIZE]>(reg) } @@ -305,7 +330,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u8) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u8) -> __m512i { + _mm512_mask_blend_epi64(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u8) -> __m512i { _mm512_mask_blend_epi64(mask, a, b) } diff --git a/src/simd/simd_i8.rs b/src/simd/simd_i8.rs index a2eefbe..4636665 100644 --- a/src/simd/simd_i8.rs +++ b/src/simd/simd_i8.rs @@ -19,7 +19,7 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_8; - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -30,7 +30,12 @@ mod avx2 { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256i) -> [i8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256i) -> [i8; LANE_SIZE] { + std::mem::transmute::<__m256i, [i8; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i8; LANE_SIZE] { std::mem::transmute::<__m256i, [i8; LANE_SIZE]>(reg) } @@ -60,7 +65,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -228,7 +238,7 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_8; - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -238,7 +248,12 @@ mod sse { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128i) -> [i8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128i) -> [i8; LANE_SIZE] { + std::mem::transmute::<__m128i, [i8; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i8; LANE_SIZE] { std::mem::transmute::<__m128i, [i8; LANE_SIZE]>(reg) } @@ -268,7 +283,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -415,7 +435,7 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_8; - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -428,7 +448,12 @@ mod avx512 { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512i) -> [i8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512i) -> [i8; LANE_SIZE] { + std::mem::transmute::<__m512i, [i8; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i8; LANE_SIZE] { std::mem::transmute::<__m512i, [i8; LANE_SIZE]>(reg) } @@ -458,7 +483,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u64) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u64) -> __m512i { + _mm512_mask_blend_epi8(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u64) -> __m512i { _mm512_mask_blend_epi8(mask, a, b) } @@ -630,20 +660,25 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_8; - impl SIMD for NEON { - const INITIAL_INDEX: int8x16_t = unsafe { + impl SIMD for NEON { + const INITIAL_INDEX: uint8x16_t = unsafe { std::mem::transmute([ - 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, - 15i8, + 0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8, + 15u8, ]) }; - const MAX_INDEX: usize = i8::MAX as usize; + const MAX_INDEX: usize = u8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: int8x16_t) -> [i8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: int8x16_t) -> [i8; LANE_SIZE] { std::mem::transmute::(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint8x16_t) -> [u8; LANE_SIZE] { + std::mem::transmute::(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const i8) -> int8x16_t { // TODO: requires v7 @@ -651,13 +686,13 @@ mod neon { } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int8x16_t { - vdupq_n_s8(a as i8) + unsafe fn _mm_set1(a: usize) -> uint8x16_t { + vdupq_n_u8(a as u8) } #[inline(always)] - unsafe fn _mm_add(a: int8x16_t, b: int8x16_t) -> int8x16_t { - vaddq_s8(a, b) + unsafe fn _mm_add(a: uint8x16_t, b: uint8x16_t) -> uint8x16_t { + vaddq_u8(a, b) } #[inline(always)] @@ -671,10 +706,15 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: int8x16_t, b: int8x16_t, mask: uint8x16_t) -> int8x16_t { + unsafe fn _mm_blendv_values(a: int8x16_t, b: int8x16_t, mask: uint8x16_t) -> int8x16_t { vbslq_s8(mask, b, a) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint8x16_t, b: uint8x16_t, mask: uint8x16_t) -> uint8x16_t { + vbslq_u8(mask, b, a) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "neon")] @@ -683,7 +723,7 @@ mod neon { } #[inline(always)] - unsafe fn _horiz_min(index: int8x16_t, value: int8x16_t) -> (usize, i8) { + unsafe fn _horiz_min(index: uint8x16_t, value: int8x16_t) -> (usize, i8) { // 0. Find the minimum value let mut vmin: int8x16_t = value; vmin = vminq_s8(vmin, vextq_s8(vmin, vmin, 8)); @@ -696,24 +736,24 @@ mod neon { // 1. Create a mask with the index of the minimum value let mask = vceqq_s8(value, vmin); // 2. Blend the mask with the index - let search_index = vbslq_s8( + let search_index = vbslq_u8( mask, index, // if mask is 1, use index - vdupq_n_s8(i8::MAX), // if mask is 0, use i8::MAX + vdupq_n_u8(u8::MAX), // if mask is 0, use u8::MAX ); // 3. Find the minimum index - let mut imin: int8x16_t = search_index; - imin = vminq_s8(imin, vextq_s8(imin, imin, 8)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 4)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 2)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 1)); - let min_index: usize = vgetq_lane_s8(imin, 0) as usize; + let mut imin: uint8x16_t = search_index; + imin = vminq_u8(imin, vextq_u8(imin, imin, 8)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 4)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 2)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 1)); + let min_index: usize = vgetq_lane_u8(imin, 0) as usize; (min_index, min_value) } #[inline(always)] - unsafe fn _horiz_max(index: int8x16_t, value: int8x16_t) -> (usize, i8) { + unsafe fn _horiz_max(index: uint8x16_t, value: int8x16_t) -> (usize, i8) { // 0. Find the maximum value let mut vmax: int8x16_t = value; vmax = vmaxq_s8(vmax, vextq_s8(vmax, vmax, 8)); @@ -726,18 +766,18 @@ mod neon { // 1. Create a mask with the index of the maximum value let mask = vceqq_s8(value, vmax); // 2. Blend the mask with the index - let search_index = vbslq_s8( + let search_index = vbslq_u8( mask, index, // if mask is 1, use index - vdupq_n_s8(i8::MAX), // if mask is 0, use i8::MAX + vdupq_n_u8(u8::MAX), // if mask is 0, use u8::MAX ); // 3. Find the maximum index - let mut imin: int8x16_t = search_index; - imin = vminq_s8(imin, vextq_s8(imin, imin, 8)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 4)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 2)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 1)); - let max_index: usize = vgetq_lane_s8(imin, 0) as usize; + let mut imin: uint8x16_t = search_index; + imin = vminq_u8(imin, vextq_u8(imin, imin, 8)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 4)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 2)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 1)); + let max_index: usize = vgetq_lane_u8(imin, 0) as usize; (max_index, max_value) } diff --git a/src/simd/simd_u16.rs b/src/simd/simd_u16.rs index c07047b..8dda657 100644 --- a/src/simd/simd_u16.rs +++ b/src/simd/simd_u16.rs @@ -42,7 +42,7 @@ mod avx2 { std::mem::transmute::<__m256i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -52,7 +52,13 @@ mod avx2 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m256i) -> [u16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m256i) -> [u16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m256i) -> [i16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -83,7 +89,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -283,13 +294,19 @@ mod sse { std::mem::transmute::<__m128i, [i16; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m128i) -> [u16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m128i) -> [u16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m128i) -> [i16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -320,7 +337,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -502,7 +524,7 @@ mod avx512 { std::mem::transmute::<__m512i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -513,7 +535,12 @@ mod avx512 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m512i) -> [u16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m512i) -> [u16; LANE_SIZE] { + unimplemented!("We work with decrordi16 and override _get_min_index_value and _get_max_index_value") + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m512i) -> [i16; LANE_SIZE] { unimplemented!("We work with decrordi16 and override _get_min_index_value and _get_max_index_value") } @@ -543,7 +570,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u32) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u32) -> __m512i { + _mm512_mask_blend_epi16(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u32) -> __m512i { _mm512_mask_blend_epi16(mask, a, b) } @@ -735,13 +767,18 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_16; - impl SIMD for NEON { + impl SIMD for NEON { const INITIAL_INDEX: uint16x8_t = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; const MAX_INDEX: usize = u16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: uint16x8_t) -> [u16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: uint16x8_t) -> [u16; LANE_SIZE] { + std::mem::transmute::(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint16x8_t) -> [u16; LANE_SIZE] { std::mem::transmute::(reg) } @@ -771,7 +808,12 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: uint16x8_t, b: uint16x8_t, mask: uint16x8_t) -> uint16x8_t { + unsafe fn _mm_blendv_values(a: uint16x8_t, b: uint16x8_t, mask: uint16x8_t) -> uint16x8_t { + vbslq_u16(mask, b, a) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint16x8_t, b: uint16x8_t, mask: uint16x8_t) -> uint16x8_t { vbslq_u16(mask, b, a) } diff --git a/src/simd/simd_u32.rs b/src/simd/simd_u32.rs index 9eaf16a..ab68bc8 100644 --- a/src/simd/simd_u32.rs +++ b/src/simd/simd_u32.rs @@ -42,13 +42,19 @@ mod avx2 { std::mem::transmute::<__m256i, [i32; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32]) }; const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m256i) -> [u32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m256i) -> [u32; LANE_SIZE] { + // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m256i) -> [i32; LANE_SIZE] { // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -79,7 +85,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -209,12 +220,18 @@ mod sse { std::mem::transmute::<__m128i, [i32; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m128i) -> [u32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m128i) -> [u32; LANE_SIZE] { + // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m128i) -> [i32; LANE_SIZE] { // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -245,7 +262,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -365,7 +387,7 @@ mod avx512 { std::mem::transmute::<__m512i, [i32; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 8i32, 9i32, 10i32, 11i32, 12i32, @@ -375,7 +397,12 @@ mod avx512 { const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m512i) -> [u32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m512i) -> [u32; LANE_SIZE] { + unimplemented!("We work with decrordu32 and override _get_min_index_value and _get_max_index_value") + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m512i) -> [i32; LANE_SIZE] { unimplemented!("We work with decrordu32 and override _get_min_index_value and _get_max_index_value") } @@ -405,7 +432,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u16) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u16) -> __m512i { + _mm512_mask_blend_epi32(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u16) -> __m512i { _mm512_mask_blend_epi32(mask, a, b) } @@ -522,12 +554,17 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_32; - impl SIMD for NEON { + impl SIMD for NEON { const INITIAL_INDEX: uint32x4_t = unsafe { std::mem::transmute([0u32, 1u32, 2u32, 3u32]) }; const MAX_INDEX: usize = u32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: uint32x4_t) -> [u32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: uint32x4_t) -> [u32; LANE_SIZE] { + std::mem::transmute::(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint32x4_t) -> [u32; LANE_SIZE] { std::mem::transmute::(reg) } @@ -557,7 +594,12 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: uint32x4_t, b: uint32x4_t, mask: uint32x4_t) -> uint32x4_t { + unsafe fn _mm_blendv_values(a: uint32x4_t, b: uint32x4_t, mask: uint32x4_t) -> uint32x4_t { + vbslq_u32(mask, b, a) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint32x4_t, b: uint32x4_t, mask: uint32x4_t) -> uint32x4_t { vbslq_u32(mask, b, a) } diff --git a/src/simd/simd_u64.rs b/src/simd/simd_u64.rs index 39b40df..5eca73b 100644 --- a/src/simd/simd_u64.rs +++ b/src/simd/simd_u64.rs @@ -39,12 +39,18 @@ mod avx2 { std::mem::transmute::<__m256i, [i64; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m256i) -> [u64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m256i) -> [u64; LANE_SIZE] { + // Not used because we work with i64ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m256i) -> [i64; LANE_SIZE] { // Not used because we work with i64ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -75,7 +81,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -205,12 +216,18 @@ mod sse { std::mem::transmute::<__m128i, [i64; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i64, 1i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m128i) -> [u64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m128i) -> [u64; LANE_SIZE] { + // Not used because we work with i64ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m128i) -> [i64; LANE_SIZE] { // Not used because we work with i64ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -241,7 +258,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -361,13 +383,18 @@ mod avx512 { std::mem::transmute::<__m512i, [i64; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m512i) -> [u64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m512i) -> [u64; LANE_SIZE] { + unimplemented!("We work with decrordi64 and override _get_min_index_value and _get_max_index_value") + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m512i) -> [i64; LANE_SIZE] { unimplemented!("We work with decrordi64 and override _get_min_index_value and _get_max_index_value") } @@ -397,7 +424,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u8) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u8) -> __m512i { + _mm512_mask_blend_epi64(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u8) -> __m512i { _mm512_mask_blend_epi64(mask, a, b) } diff --git a/src/simd/simd_u8.rs b/src/simd/simd_u8.rs index b6ababf..aefe891 100644 --- a/src/simd/simd_u8.rs +++ b/src/simd/simd_u8.rs @@ -42,7 +42,7 @@ mod avx2 { std::mem::transmute::<__m256i, [i8; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -53,7 +53,13 @@ mod avx2 { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m256i) -> [u8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m256i) -> [u8; LANE_SIZE] { + // Not used because we work with i8ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m256i) -> [i8; LANE_SIZE] { // Not used because we work with i8ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -84,7 +90,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -278,7 +289,7 @@ mod sse { std::mem::transmute::<__m128i, [i8; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -288,7 +299,13 @@ mod sse { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m128i) -> [u8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m128i) -> [u8; LANE_SIZE] { + // Not used because we work with i8ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m128i) -> [i8; LANE_SIZE] { // Not used because we work with i8ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -319,7 +336,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -495,7 +517,7 @@ mod avx512 { std::mem::transmute::<__m512i, [i8; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -508,7 +530,14 @@ mod avx512 { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m512i) -> [u8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m512i) -> [u8; LANE_SIZE] { + unimplemented!( + "We work with decrordi8 and override _get_min_index_value and _get_max_index_value" + ) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m512i) -> [i8; LANE_SIZE] { unimplemented!( "We work with decrordi8 and override _get_min_index_value and _get_max_index_value" ) @@ -540,7 +569,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u64) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u64) -> __m512i { + _mm512_mask_blend_epi8(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u64) -> __m512i { _mm512_mask_blend_epi8(mask, a, b) } @@ -726,7 +760,7 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_8; - impl SIMD for NEON { + impl SIMD for NEON { const INITIAL_INDEX: uint8x16_t = unsafe { std::mem::transmute([ 0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8, @@ -736,7 +770,12 @@ mod neon { const MAX_INDEX: usize = u8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: uint8x16_t) -> [u8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: uint8x16_t) -> [u8; LANE_SIZE] { + std::mem::transmute::(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint8x16_t) -> [u8; LANE_SIZE] { std::mem::transmute::(reg) } @@ -766,7 +805,12 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: uint8x16_t, b: uint8x16_t, mask: uint8x16_t) -> uint8x16_t { + unsafe fn _mm_blendv_values(a: uint8x16_t, b: uint8x16_t, mask: uint8x16_t) -> uint8x16_t { + vbslq_u8(mask, b, a) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint8x16_t, b: uint8x16_t, mask: uint8x16_t) -> uint8x16_t { vbslq_u8(mask, b, a) } diff --git a/src/utils.rs b/src/utils.rs index 55fd93a..d91a177 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -23,7 +23,11 @@ // } #[inline(always)] -pub(crate) fn min_index_value(index: &[T], values: &[T]) -> (T, T) { +pub(crate) fn min_index_value(index: &[Tidx], values: &[Tval]) -> (Tidx, Tval) +where + Tidx: Copy + PartialOrd, + Tval: Copy + PartialOrd, +{ assert_eq!(index.len(), values.len()); values .iter() @@ -62,7 +66,11 @@ pub(crate) fn min_index_value(index: &[T], values: &[T]) - // } #[inline(always)] -pub(crate) fn max_index_value(index: &[T], values: &[T]) -> (T, T) { +pub(crate) fn max_index_value(index: &[Tidx], values: &[Tval]) -> (Tidx, Tval) +where + Tidx: Copy + PartialOrd, + Tval: Copy + PartialOrd, +{ assert_eq!(index.len(), values.len()); values .iter()