Revise unroll scheme
- Simpler code, with less duplication
- Faster compile times
- Wider unroll with smaller footprint
- Better perf
Voultapher committed Jul 18, 2023
1 parent 076d231 · commit 923d106
Showing 2 changed files with 77 additions and 149 deletions.
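
The commit message is easiest to read next to a toy version of the change. Below is a minimal sketch (illustrative only, not the crate's code) of the new scheme used in src/quicksort.rs: instead of duplicating the loop body through a macro and fixing up the tail, the body is written once inside a bounded inner loop whose width is chosen per type. The hypothetical `count_less_unrolled` and its plain size check stand in for the real `stable_partition` and `is_int_like_type`.

```rust
// Illustrative only: a self-contained sketch of the revised unroll scheme, not
// the crate's code. `is_int_like_type` from the diff below is stood in for by
// a plain size check, and `if const { .. }` by a regular branch.
fn count_less_unrolled<T, F: FnMut(&T) -> bool>(v: &[T], mut pred: F) -> usize {
    // New scheme: one copy of the loop body, widened by a per-type factor.
    let unroll_len = if std::mem::size_of::<T>() <= std::mem::size_of::<u64>() {
        4
    } else {
        1 // Leave it to the optimizer for larger types.
    };

    let mut count = 0;
    let mut base_i = 0;
    'outer: loop {
        for unroll_i in 0..unroll_len {
            let i = base_i + unroll_i;
            if i >= v.len() {
                break 'outer;
            }
            count += pred(&v[i]) as usize;
        }
        base_i += unroll_len;
    }
    count
}

fn main() {
    let v = [5_u64, 1, 9, 3, 7];
    // The old scheme expanded the body twice via macro_rules! plus a tail
    // fixup (see the removed impl further down); the new one keeps one body.
    assert_eq!(count_less_unrolled(&v, |x| *x < 6), 3);
}
```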
14 changes: 10 additions & 4 deletions src/lib.rs
@@ -297,9 +297,15 @@ impl<T: Freeze> const IsFreeze for T {

#[must_use]
const fn has_direct_interior_mutability<T>() -> bool {
// - Can the type have interior mutability, this is checked by testing if T is Freeze.
// If the type can have interior mutability it may alter itself during comparison in a way
// that must be observed after the sort operation concludes.
// Otherwise a type like Mutex<Option<Box<str>>> could lead to double free.
// Can the type have interior mutability, this is checked by testing if T is Freeze. If the type
// can have interior mutability it may alter itself during comparison in a way that must be
// observed after the sort operation concludes. Otherwise a type like Mutex<Option<Box<str>>>
// could lead to double free.
!<T as IsFreeze>::value()
}

#[must_use]
const fn is_int_like_type<T>() -> bool {
// A heuristic that guesses whether a type looks like an int for optimization purposes.
<T as IsFreeze>::value() && mem::size_of::<T>() <= mem::size_of::<u64>()
}
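
The rewrapped comment above describes the hazard only abstractly. The safe snippet below is illustrative only: `Cell<i32>` stands in for something heavier like the `Mutex<Option<Box<str>>>` mentioned in the comment, and it demonstrates the observable-mutation aspect the `Freeze` check guards against (not the double free itself): with interior mutability, a comparison can change an element after the sort has taken a copy of it.

```rust
use std::cell::Cell;

// Illustrative only: why the Freeze check matters. A comparator can mutate an
// element through &self, so a copy of the pivot taken before the comparisons
// can go stale and must not silently replace the original afterwards.
fn main() {
    let v = [Cell::new(3), Cell::new(1), Cell::new(2)];
    let pivot_snapshot = v[0].get(); // like reading the pivot into scratch

    // A "less than" that also mutates the element it inspects.
    let is_less = |a: &Cell<i32>, b: &Cell<i32>| {
        a.set(a.get() + 10); // observable side effect through a shared reference
        a.get() < b.get()
    };
    let _ = is_less(&v[0], &v[1]);

    // The snapshot no longer matches what the caller can still observe in `v`.
    assert_ne!(pivot_snapshot, v[0].get());
}
```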
212 changes: 67 additions & 145 deletions src/quicksort.rs
@@ -62,17 +62,14 @@ pub fn stable_quicksort<T, F>(
let mut mid = 0;

if !should_do_equal_partition {
mid = <T as StablePartitionTypeImpl>::stable_partition(v, scratch, pivot, is_less);
mid = stable_partition(v, scratch, pivot, is_less);

// Fallback for non Freeze types.
should_do_equal_partition = mid == 0;
}

if should_do_equal_partition {
let mid_eq =
<T as StablePartitionTypeImpl>::stable_partition(v, scratch, pivot, &mut |a, b| {
!is_less(b, a)
});
let mid_eq = stable_partition(v, scratch, pivot, &mut |a, b| !is_less(b, a));
v = &mut v[mid_eq..];
ancestor_pivot = None;
continue;
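
The equal-partition fallback above reuses `stable_partition` with a reversed, negated comparator. A tiny worked example of what that comparator selects (illustrative only; in the real code this path runs when no element compares strictly less than the pivot):

```rust
// Illustrative only: the flipped-and-negated comparator from the hunk above in
// isolation. The partition calls the comparator as (elem, pivot), so passing
// `|a, b| !is_less(b, a)` evaluates `!is_less(pivot, elem)`, i.e. `elem <= pivot`
// under a total order, which groups the pivot-equal run on the left.
fn main() {
    let is_less = |a: &i32, b: &i32| a < b;
    let pivot = 5;
    let v = [5, 7, 5, 9, 5, 6];

    let le_pivot: Vec<i32> = v.iter().copied().filter(|e| !is_less(&pivot, e)).collect();
    let gt_pivot: Vec<i32> = v.iter().copied().filter(|e| is_less(&pivot, e)).collect();

    assert_eq!(le_pivot, vec![5, 5, 5]);
    assert_eq!(gt_pivot, vec![7, 9, 6]);
}
```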
@@ -187,38 +184,18 @@ where
}
}

// The manual unrolling required for good perf for integer like types has a big impact on debug
// compile times. To limit unnecessary code-gen this is put into a trait.
trait StablePartitionTypeImpl: Sized {
/// Partitions `v` into elements smaller than `pivot`, followed by elements
/// greater than or equal to `pivot`.
///
/// Returns the number of elements smaller than `pivot`.
fn stable_partition<F>(
v: &mut [Self],
scratch: &mut [MaybeUninit<Self>],
pivot_pos: usize,
is_less: &mut F,
) -> usize
where
F: FnMut(&Self, &Self) -> bool;
}

impl<T> StablePartitionTypeImpl for T {
default fn stable_partition<F>(
v: &mut [Self],
scratch: &mut [MaybeUninit<Self>],
pivot_pos: usize,
is_less: &mut F,
) -> usize
where
F: FnMut(&Self, &Self) -> bool,
{
stable_partition_default(v, scratch, pivot_pos, is_less)
}
}

fn stable_partition_default<T, F>(
/// Takes the input slice `v` and re-arranges elements such that when the call returns normally all
/// elements that compare true for `is_less(elem, pivot)` where `pivot == v[pivot_pos]` are on the
/// left side of `v` followed by the other elements, notionally considered greater or equal to
/// `pivot`.
///
/// Returns the number of elements that are compared true for `is_less(elem, pivot)`.
///
/// If `is_less` does not implement a total order the resulting order and return value are
/// unspecified. All original elements will remain in `v` and any possible modifications via
/// interior mutability will be observable. Same is true if `is_less` panics or `v.len()` exceeds
/// `scratch.len()`.
fn stable_partition<T, F>(
v: &mut [T],
scratch: &mut [MaybeUninit<T>],
pivot_pos: usize,
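
The new doc comment above pins down the contract of the now-single `stable_partition` (its signature continues in the collapsed lines). Here is a safe model of that contract; `stable_partition_model` is a hypothetical illustration, while the real function works in place through the caller-provided scratch buffer and allocates nothing.

```rust
// Illustrative only: a safe, allocation-based model of the documented contract,
// not the crate's in-place, scratch-buffer implementation.
fn stable_partition_model<T: Clone, F: FnMut(&T, &T) -> bool>(
    v: &mut [T],
    pivot_pos: usize,
    is_less: &mut F,
) -> usize {
    let pivot = v[pivot_pos].clone();
    let (lt, ge): (Vec<T>, Vec<T>) = v.iter().cloned().partition(|e| is_less(e, &pivot));
    let lt_count = lt.len();
    for (dst, src) in v.iter_mut().zip(lt.into_iter().chain(ge)) {
        *dst = src;
    }
    lt_count
}

fn main() {
    let mut v = [3, 7, 1, 5, 2, 5, 9];
    let mid = stable_partition_model(&mut v, 3, &mut |a, b| a < b); // pivot = v[3] = 5
    assert_eq!(mid, 3);
    // Elements on each side keep their original relative order (stability).
    assert_eq!(v, [3, 1, 2, 7, 5, 5, 9]);
}
```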
@@ -261,37 +238,65 @@ where
let mut lt_count = 0;
let mut ge_out_ptr = scratch_ptr.add(len);

for i in 0..len {
ge_out_ptr = ge_out_ptr.sub(1);

let elem_ptr = arr_ptr.add(i);
let unroll_len = if const { crate::is_int_like_type::<T>() } {
4
} else {
1 // If the optimizer is convinced it can still unroll this.
};

// This is required to
// handle types with interior mutability. See comment above for more info.
if const { crate::has_direct_interior_mutability::<T>() }
&& intrinsics::unlikely(elem_ptr as *const T == original_pivot_elem_ptr)
{
// We move the pivot in its correct place later.
if is_less(pivot, pivot) {
pivot_out_ptr = scratch_ptr.add(lt_count);
lt_count += 1;
} else {
pivot_out_ptr = ge_out_ptr.add(lt_count);
// Loop manually unrolled to ensure good performance.
// Example T == u64, on x86 LLVM unrolls this loop but not on Arm.
// And it's very perf critical so this is done manually.
// And surprisingly this can yield better code-gen and perf than the auto-unroll.
// TODO update comment to explain custom unrolling without duplicating the inner part.
let mut base_i = 0;
'outer: loop {
for unroll_i in 0..unroll_len {
let i = base_i + unroll_i;
if intrinsics::unlikely(i >= len) {
break 'outer;
}
} else {
let is_less_than_pivot = is_less(&*elem_ptr, pivot);

let dst_ptr = if is_less_than_pivot {
scratch_ptr
let elem_ptr = arr_ptr.add(i);

ge_out_ptr = ge_out_ptr.sub(1);

// This is required to
// handle types with interior mutability. See comment above for more info.
if const { crate::has_direct_interior_mutability::<T>() }
&& intrinsics::unlikely(elem_ptr as *const T == original_pivot_elem_ptr)
{
// We move the pivot in its correct place later.
if is_less(pivot, pivot) {
pivot_out_ptr = scratch_ptr.add(lt_count);
lt_count += 1;
} else {
pivot_out_ptr = ge_out_ptr.add(lt_count);
}
} else {
ge_out_ptr
};
ptr::copy_nonoverlapping(elem_ptr, dst_ptr.add(lt_count), 1);
let is_less_than_pivot = is_less(&*elem_ptr, pivot);

lt_count += is_less_than_pivot as usize;
if const { mem::size_of::<T>() <= mem::size_of::<u64>() } {
// Benchmarks show that especially on Firestorm (apple-m1) for anything at
// most the size of a u64, double storing is more efficient than conditional
// store. It is also less at risk of having the compiler generating a branch
// instead of conditional store.
ptr::copy_nonoverlapping(elem_ptr, scratch_ptr.add(lt_count), 1);
ptr::copy_nonoverlapping(elem_ptr, ge_out_ptr.add(lt_count), 1);
} else {
let dst_ptr = if is_less_than_pivot {
scratch_ptr
} else {
ge_out_ptr
};
ptr::copy_nonoverlapping(elem_ptr, dst_ptr.add(lt_count), 1);
}

lt_count += is_less_than_pivot as usize;
}
}

base_i += unroll_len;
}
// }

// Now that any possible observation of pivot has happened we copy it.
if const { has_direct_interior_mutability::<T>() } {
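
To see why the hunk above can afford to store every element twice for int-like types, here is a safe model of that path together with the copy-back that follows it. It is illustrative only: `partition_int_like` is a hypothetical stand-in that allocates its own scratch, whereas the real code writes through raw pointers into the caller's buffer and falls back to a single conditional store for larger types.

```rust
// Illustrative only: a safe model of the double-store strategy and of the
// copy-back step, using a Vec instead of raw scratch pointers.
fn partition_int_like(v: &mut [u64], pivot: u64) -> usize {
    let len = v.len();
    let mut scratch = vec![0_u64; len];

    // lt == less than, ge == greater or equal (mirrors the hunk above).
    let mut lt_count = 0;
    let mut ge_out = len; // walks down from the end of the scratch buffer
    for i in 0..len {
        ge_out -= 1;
        let elem = v[i];
        let is_less_than_pivot = elem < pivot;

        // Double store: write to both candidate slots, then advance lt_count
        // conditionally. Only the slot that "counts" survives; the other is
        // overwritten or ignored later. This avoids a data-dependent branch.
        scratch[lt_count] = elem;
        scratch[ge_out + lt_count] = elem;
        lt_count += is_less_than_pivot as usize;
    }

    // Copy back: the "< pivot" block goes in forward order, the ">= pivot"
    // block was filled from the back and is copied back reversed so both
    // sides keep their original relative order.
    v[..lt_count].copy_from_slice(&scratch[..lt_count]);
    for i in 0..len - lt_count {
        v[lt_count + i] = scratch[len - 1 - i];
    }
    lt_count
}

fn main() {
    let mut v = [3_u64, 9, 1, 7, 5];
    let lt_count = partition_int_like(&mut v, 5);
    assert_eq!(lt_count, 2);
    assert_eq!(v, [3, 1, 9, 7, 5]); // stable on both sides
}
```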
@@ -312,89 +317,6 @@ where
}
}

impl<T: crate::Freeze + Copy> StablePartitionTypeImpl for T {
fn stable_partition<F>(
v: &mut [Self],
scratch: &mut [MaybeUninit<Self>],
pivot_pos: usize,
is_less: &mut F,
) -> usize
where
F: FnMut(&Self, &Self) -> bool,
{
if const { mem::size_of::<T>() <= mem::size_of::<usize>() } {
let len = v.len();
let arr_ptr = v.as_mut_ptr();

if intrinsics::unlikely(scratch.len() < len || pivot_pos >= len) {
debug_assert!(false); // That's a logic bug in the implementation.
return 0;
}

// SAFETY: TODO
unsafe {
let pivot_value = ptr::read(&v[pivot_pos]);
let pivot: &T = &pivot_value;

let scratch_ptr = MaybeUninit::slice_as_mut_ptr(scratch);

// lt == less than, ge == greater or equal
let mut lt_count = 0;
let mut ge_out_ptr = scratch_ptr.add(len);

// Loop manually unrolled to ensure good performance.
// Example T == u64, on x86 LLVM unrolls this loop but not on Arm.
// And it's very perf critical so this is done manually.
// And surprisingly this can yield better code-gen and perf than the auto-unroll.
macro_rules! loop_body {
($elem_ptr:expr) => {
ge_out_ptr = ge_out_ptr.sub(1);

let elem_ptr = $elem_ptr;

let is_less_than_pivot = is_less(&*elem_ptr, pivot);

// Benchmarks show that especially on Firestorm (apple-m1) for anything at
// most the size of a u64 double storing is more efficient than conditional
// store. It is also less at risk of having the compiler generating a branch
// instead of conditional store.
ptr::copy_nonoverlapping(elem_ptr, scratch_ptr.add(lt_count), 1);
ptr::copy_nonoverlapping(elem_ptr, ge_out_ptr.add(lt_count), 1);

lt_count += is_less_than_pivot as usize;
};
}

let mut i: usize = 0;
let end = len.saturating_sub(1);

while i < end {
loop_body!(arr_ptr.add(i));
loop_body!(arr_ptr.add(i + 1));
i += 2;
}

if i != len {
loop_body!(arr_ptr.add(i));
}

// Copy all the elements that were not equal directly from swap to v.
ptr::copy_nonoverlapping(scratch_ptr, arr_ptr, lt_count);

// Copy the elements that were equal or more from the buf into v and reverse them.
let rev_buf_ptr = scratch_ptr.add(len - 1);
for i in 0..len - lt_count {
ptr::copy_nonoverlapping(rev_buf_ptr.sub(i), arr_ptr.add(lt_count + i), 1);
}

lt_count
}
} else {
stable_partition_default(v, scratch, pivot_pos, is_less)
}
}
}

// It's crucial that pivot_hole will be copied back to the input if any comparison in the
// loop panics. Because it could have changed due to interior mutability.
struct PivotGuard<T> {
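
The hunk ends just as `PivotGuard` begins, and its body is collapsed in this view. Below is a generic sketch of the copy-back-on-drop pattern the comment above describes; `CopyBackGuard`, its field names, and the usage in `main` are hypothetical illustrations, not necessarily the crate's actual definition.

```rust
use std::mem::ManuallyDrop;
use std::ptr;

// Illustrative only: a generic copy-back-on-drop guard with hypothetical names.
struct CopyBackGuard<T> {
    // Bitwise copy of the pivot read out of the input; ManuallyDrop because
    // ownership logically stays with the slice element it was read from.
    value: ManuallyDrop<T>,
    // Where the (possibly mutated) pivot must end up again, even on panic.
    hole: *mut T,
}

impl<T> Drop for CopyBackGuard<T> {
    fn drop(&mut self) {
        // Runs on normal exit and during unwinding alike, so a panicking
        // comparison cannot leave the input without its pivot value.
        // SAFETY: the sketch assumes `hole` still points into the input slice
        // and is valid for a write of one T.
        unsafe {
            ptr::copy_nonoverlapping(&*self.value, self.hole, 1);
        }
    }
}

fn main() {
    let mut v = [String::from("b"), String::from("a")];
    let hole = &mut v[0] as *mut String;
    // Read the pivot out; the guard guarantees it is written back when it goes
    // out of scope, whether or not the code in between panics.
    let guard = CopyBackGuard {
        value: ManuallyDrop::new(unsafe { ptr::read(hole) }),
        hole,
    };
    // ... comparisons that may panic would run here ...
    drop(guard);
    assert_eq!(v[0], "b");
}
```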