Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions benches/bench_f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ fn minmax_f32_random_array_long(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_random_long_f32", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) })
});
Expand Down Expand Up @@ -68,7 +68,7 @@ fn minmax_f32_random_array_short(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_random_short_f32", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) })
});
Expand Down Expand Up @@ -109,7 +109,7 @@ fn minmax_f32_worst_case_array_long(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_worst_long_f32", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) })
});
Expand Down Expand Up @@ -150,7 +150,7 @@ fn minmax_f32_worst_case_array_short(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_worst_short_f32", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) })
});
Expand Down
8 changes: 4 additions & 4 deletions benches/bench_f64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn minmax_f64_random_array_long(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_random_long_f64", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) })
});
Expand Down Expand Up @@ -54,7 +54,7 @@ fn minmax_f64_random_array_short(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_random_short_f64", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) })
});
Expand Down Expand Up @@ -83,7 +83,7 @@ fn minmax_f64_worst_case_array_long(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_worst_long_f64", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) })
});
Expand Down Expand Up @@ -112,7 +112,7 @@ fn minmax_f64_worst_case_array_short(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_worst_short_f64", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) })
});
Expand Down
3 changes: 0 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,6 @@ macro_rules! impl_argminmax {
return unsafe { AVX512::argminmax(self) }
} else if is_x86_feature_detected!("avx2") {
return unsafe { AVX2::argminmax(self) }
} else if is_x86_feature_detected!("avx") & (<$t>::NB_BITS >= 32) & (<$t>::IS_FLOAT == true) {
// f32 and f64 do not require avx2
return unsafe { AVX2::argminmax(self) }
// Skip SSE4.2 because the scalar implementation is faster or equivalent for 64-bit numbers
// // } else if is_x86_feature_detected!("sse4.2") & (<$t>::NB_BITS == 64) & (<$t>::IS_FLOAT == false) {
// // SSE4.2 is needed for comparing 64-bit integers
Expand Down
96 changes: 59 additions & 37 deletions src/simd/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@ use crate::scalar::{ScalarArgMinMax, SCALAR};
// TODO: other potential generic SIMDIndexDtype: Copy
#[allow(clippy::missing_safety_doc)] // TODO: add safety docs?
pub trait SIMD<
ScalarDType: Copy + PartialOrd + AsPrimitive<usize>,
SIMDVecDtype: Copy,
ValueDType: Copy + PartialOrd,
SIMDValueDtype: Copy,
IndexDtype: Copy + PartialOrd + AsPrimitive<usize>,
SIMDIndexDtype: Copy,
SIMDMaskDtype: Copy,
const LANE_SIZE: usize,
>
{
const INITIAL_INDEX: SIMDVecDtype;
const MAX_INDEX: usize; // Integers > this value **cannot** be accurately represented in SIMDVecDtype
const INITIAL_INDEX: SIMDIndexDtype;
const MAX_INDEX: usize; // Integers > this value **cannot** be accurately represented in SIMDIndexDtype

#[inline(always)]
fn _find_largest_lower_multiple_of_lane_size(n: usize) -> usize {
Expand All @@ -24,22 +26,34 @@ pub trait SIMD<

// ------------------------------------ SIMD HELPERS --------------------------------------

unsafe fn _reg_to_arr(reg: SIMDVecDtype) -> [ScalarDType; LANE_SIZE];
unsafe fn _reg_to_arr_values(reg: SIMDValueDtype) -> [ValueDType; LANE_SIZE];

unsafe fn _mm_loadu(data: *const ScalarDType) -> SIMDVecDtype;
unsafe fn _reg_to_arr_indices(reg: SIMDIndexDtype) -> [IndexDtype; LANE_SIZE];

unsafe fn _mm_set1(a: usize) -> SIMDVecDtype;
unsafe fn _mm_loadu(data: *const ValueDType) -> SIMDValueDtype;

unsafe fn _mm_add(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDVecDtype;
unsafe fn _mm_set1(a: usize) -> SIMDIndexDtype;

unsafe fn _mm_cmpgt(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDMaskDtype;
unsafe fn _mm_add(a: SIMDIndexDtype, b: SIMDIndexDtype) -> SIMDIndexDtype;

unsafe fn _mm_cmplt(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDMaskDtype;
unsafe fn _mm_cmpgt(a: SIMDValueDtype, b: SIMDValueDtype) -> SIMDMaskDtype;

unsafe fn _mm_blendv(a: SIMDVecDtype, b: SIMDVecDtype, mask: SIMDMaskDtype) -> SIMDVecDtype;
unsafe fn _mm_cmplt(a: SIMDValueDtype, b: SIMDValueDtype) -> SIMDMaskDtype;

unsafe fn _mm_blendv_values(
a: SIMDValueDtype,
b: SIMDValueDtype,
mask: SIMDMaskDtype,
) -> SIMDValueDtype;

unsafe fn _mm_blendv_indices(
a: SIMDIndexDtype,
b: SIMDIndexDtype,
mask: SIMDMaskDtype,
) -> SIMDIndexDtype;

#[inline(always)]
unsafe fn _horiz_min(index: SIMDVecDtype, value: SIMDVecDtype) -> (usize, ScalarDType) {
unsafe fn _horiz_min(index: SIMDIndexDtype, value: SIMDValueDtype) -> (usize, ValueDType) {
// This becomes the bottleneck when using 8-bit data types, as for every 2**7
// or 2**8 elements, the SIMD inner loop is executed (& thus also terminated)
// to avoid overflow.
Expand All @@ -48,14 +62,14 @@ pub trait SIMD<
// see: https://stackoverflow.com/a/9798369
// Note: this is not a bottleneck for 16-bit data types, as the termination of
// the SIMD inner loop is 2**8 times less frequent.
let index_arr = Self::_reg_to_arr(index);
let value_arr = Self::_reg_to_arr(value);
let index_arr = Self::_reg_to_arr_indices(index);
let value_arr = Self::_reg_to_arr_values(value);
let (min_index, min_value) = min_index_value(&index_arr, &value_arr);
(min_index.as_(), min_value)
}

#[inline(always)]
unsafe fn _horiz_max(index: SIMDVecDtype, value: SIMDVecDtype) -> (usize, ScalarDType) {
unsafe fn _horiz_max(index: SIMDIndexDtype, value: SIMDValueDtype) -> (usize, ValueDType) {
// This becomes the bottleneck when using 8-bit data types, as for every 2**7
// or 2**8 elements, the SIMD inner loop is executed (& thus also terminated)
// to avoid overflow.
Expand All @@ -64,14 +78,14 @@ pub trait SIMD<
// see: https://stackoverflow.com/a/9798369
// Note: this is not a bottleneck for 16-bit data types, as the termination of
// the SIMD inner loop is 2**8 times less frequent.
let index_arr = Self::_reg_to_arr(index);
let value_arr = Self::_reg_to_arr(value);
let index_arr = Self::_reg_to_arr_indices(index);
let value_arr = Self::_reg_to_arr_values(value);
let (max_index, max_value) = max_index_value(&index_arr, &value_arr);
(max_index.as_(), max_value)
}

#[inline(always)]
unsafe fn _mm_prefetch(data: *const ScalarDType) {
unsafe fn _mm_prefetch(data: *const ValueDType) {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(target_arch = "x86")]
Expand All @@ -91,20 +105,20 @@ pub trait SIMD<

// ------------------------------------ ARGMINMAX --------------------------------------

unsafe fn argminmax(data: ArrayView1<ScalarDType>) -> (usize, usize);
unsafe fn argminmax(data: ArrayView1<ValueDType>) -> (usize, usize);

#[inline(always)]
unsafe fn _argminmax(data: ArrayView1<ScalarDType>) -> (usize, usize)
unsafe fn _argminmax(data: ArrayView1<ValueDType>) -> (usize, usize)
where
SCALAR: ScalarArgMinMax<ScalarDType>,
SCALAR: ScalarArgMinMax<ValueDType>,
{
argminmax_generic(data, LANE_SIZE, Self::_overflow_safe_core_argminmax)
}

#[inline(always)]
unsafe fn _overflow_safe_core_argminmax(
arr: ArrayView1<ScalarDType>,
) -> (usize, ScalarDType, usize, ScalarDType) {
arr: ArrayView1<ValueDType>,
) -> (usize, ValueDType, usize, ValueDType) {
// 0. Get the max index value of the data type, rounded down to the largest multiple of LANE_SIZE
let dtype_max = Self::_find_largest_lower_multiple_of_lane_size(Self::MAX_INDEX);

Expand Down Expand Up @@ -171,20 +185,20 @@ pub trait SIMD<
// TODO: can be cleaner (perhaps?)
#[inline(always)]
unsafe fn _get_min_max_index_value(
index_low: SIMDVecDtype,
values_low: SIMDVecDtype,
index_high: SIMDVecDtype,
values_high: SIMDVecDtype,
) -> (usize, ScalarDType, usize, ScalarDType) {
index_low: SIMDIndexDtype,
values_low: SIMDValueDtype,
index_high: SIMDIndexDtype,
values_high: SIMDValueDtype,
) -> (usize, ValueDType, usize, ValueDType) {
let (min_index, min_value) = Self::_horiz_min(index_low, values_low);
let (max_index, max_value) = Self::_horiz_max(index_high, values_high);
(min_index, min_value, max_index, max_value)
}

#[inline(always)]
unsafe fn _core_argminmax(
arr: ArrayView1<ScalarDType>,
) -> (usize, ScalarDType, usize, ScalarDType) {
arr: ArrayView1<ValueDType>,
) -> (usize, ValueDType, usize, ValueDType) {
assert_eq!(arr.len() % LANE_SIZE, 0);
// Efficient calculation of argmin and argmax together
let mut new_index = Self::INITIAL_INDEX;
Expand Down Expand Up @@ -228,12 +242,12 @@ pub trait SIMD<
let gt_mask = Self::_mm_cmpgt(new_values, values_high);

// Update the highest and lowest values
values_low = Self::_mm_blendv(values_low, new_values, lt_mask);
values_high = Self::_mm_blendv(values_high, new_values, gt_mask);
values_low = Self::_mm_blendv_values(values_low, new_values, lt_mask);
values_high = Self::_mm_blendv_values(values_high, new_values, gt_mask);

// Update the index if the new value is lower/higher
index_low = Self::_mm_blendv(index_low, new_index, lt_mask);
index_high = Self::_mm_blendv(index_high, new_index, gt_mask);
index_low = Self::_mm_blendv_indices(index_low, new_index, lt_mask);
index_high = Self::_mm_blendv_indices(index_high, new_index, gt_mask);

// 25 is an empirically chosen value without rigorous justification, but it seems to work well overall
// => TODO: this should probably be a function of the data type
Expand All @@ -247,11 +261,15 @@ pub trait SIMD<
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
macro_rules! unimplement_simd {
($scalar_type:ty, $reg:ty, $simd_type:ident) => {
impl SIMD<$scalar_type, $reg, $reg, 0> for $simd_type {
impl SIMD<$scalar_type, $reg, $scalar_type, $reg, $reg, 0> for $simd_type {
const INITIAL_INDEX: $reg = 0;
const MAX_INDEX: usize = 0;

unsafe fn _reg_to_arr(_reg: $reg) -> [$scalar_type; 0] {
unsafe fn _reg_to_arr_values(_reg: $reg) -> [$scalar_type; 0] {
unimplemented!()
}

unsafe fn _reg_to_arr_indices(_reg: $reg) -> [$scalar_type; 0] {
unimplemented!()
}

Expand All @@ -275,7 +293,11 @@ macro_rules! unimplement_simd {
unimplemented!()
}

unsafe fn _mm_blendv(_a: $reg, _b: $reg, _mask: $reg) -> $reg {
unsafe fn _mm_blendv_values(_a: $reg, _b: $reg, _mask: $reg) -> $reg {
unimplemented!()
}

unsafe fn _mm_blendv_indices(_a: $reg, _b: $reg, _mask: $reg) -> $reg {
unimplemented!()
}

Expand Down
Loading