diff --git a/src/drift.rs b/src/drift.rs index c52e03f..b45cd4a 100644 --- a/src/drift.rs +++ b/src/drift.rs @@ -124,7 +124,7 @@ pub fn sort bool>( // as our threshold, as we will call small_sort on any runs smaller than this. const MIN_MERGE_SLICE_LEN: usize = 32; let min_good_run_len = if eager_sort { - T::MAX_LEN_SMALL_SORT + T::SMALL_SORT_THRESHOLD } else if len <= (MIN_MERGE_SLICE_LEN * MIN_MERGE_SLICE_LEN) { MIN_MERGE_SLICE_LEN } else { @@ -253,10 +253,7 @@ fn create_run bool>( /// /// Returns the length of the run, and a bool that is false when the run /// is ascending, and true if the run strictly descending. -fn find_existing_run(v: &[T], is_less: &mut F) -> (usize, bool) -where - F: FnMut(&T, &T) -> bool, -{ +fn find_existing_run bool>(v: &[T], is_less: &mut F) -> (usize, bool) { let len = v.len(); if len < 2 { return (len, false); diff --git a/src/lib.rs b/src/lib.rs index b666383..9c2e088 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,7 @@ use core::slice; mod drift; mod merge; +mod pivot; mod quicksort; mod smallsort; @@ -65,11 +66,7 @@ fn stable_sort bool>(v: &mut [T], mut is_less: F) { } #[inline(always)] -fn driftsort(v: &mut [T], is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, - BufT: BufGuard, -{ +fn driftsort bool, BufT: BufGuard>(v: &mut [T], is_less: &mut F) { // Arrays of zero-sized types are always all-equal, and thus sorted. if T::IS_ZST { return; @@ -100,11 +97,7 @@ where // Deliberately don't inline the core logic to ensure the inlined insertion sort i-cache footprint // is minimal. #[inline(never)] -fn driftsort_main(v: &mut [T], is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, - BufT: BufGuard, -{ +fn driftsort_main bool, BufT: BufGuard>(v: &mut [T], is_less: &mut F) { // Pick whichever is greater: // - alloc len elements up to MAX_FULL_ALLOC_BYTES // - alloc len / 2 elements diff --git a/src/merge.rs b/src/merge.rs index a1cbc1f..de2bf3f 100644 --- a/src/merge.rs +++ b/src/merge.rs @@ -4,10 +4,12 @@ use core::ptr; /// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and /// stores the result into `v[..]`. -pub fn merge(v: &mut [T], scratch: &mut [MaybeUninit], mid: usize, is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ +pub fn merge bool>( + v: &mut [T], + scratch: &mut [MaybeUninit], + mid: usize, + is_less: &mut F, +) { let len = v.len(); if mid == 0 || mid >= len || scratch.len() < cmp::min(mid, len - mid) { diff --git a/src/pivot.rs b/src/pivot.rs new file mode 100644 index 0000000..6eeedda --- /dev/null +++ b/src/pivot.rs @@ -0,0 +1,84 @@ +// Recursively select a pseudomedian if above this threshold. +const PSEUDO_MEDIAN_REC_THRESHOLD: usize = 64; + +/// Selects a pivot from `v`. Algorithm taken from glidesort by Orson Peters. +/// +/// This chooses a pivot by sampling an adaptive amount of points, approximating +/// the quality of a median of sqrt(n) elements. +pub fn choose_pivot bool>(v: &[T], is_less: &mut F) -> usize { + // We use unsafe code and raw pointers here because we're dealing with + // heavy recursion. Passing safe slices around would involve a lot of + // branches and function call overhead. + + // SAFETY: a, b, c point to initialized regions of len_div_8 elements, + // satisfying median3 and median3_rec's preconditions as v_base points + // to an initialized region of n = len elements. + unsafe { + let v_base = v.as_ptr(); + let len = v.len(); + let len_div_8 = len / 8; + + let a = v_base; // [0, floor(n/8)) + let b = v_base.add(len_div_8 * 4); // [4*floor(n/8), 5*floor(n/8)) + let c = v_base.add(len_div_8 * 7); // [7*floor(n/8), 8*floor(n/8)) + + if len < PSEUDO_MEDIAN_REC_THRESHOLD { + median3(&*a, &*b, &*c, is_less).sub_ptr(v_base) + } else { + median3_rec(a, b, c, len_div_8, is_less).sub_ptr(v_base) + } + } +} + +/// Calculates an approximate median of 3 elements from sections a, b, c, or +/// recursively from an approximation of each, if they're large enough. By +/// dividing the size of each section by 8 when recursing we have logarithmic +/// recursion depth and overall sample from f(n) = 3*f(n/8) -> f(n) = +/// O(n^(log(3)/log(8))) ~= O(n^0.528) elements. +/// +/// SAFETY: a, b, c must point to the start of initialized regions of memory of +/// at least n elements. +unsafe fn median3_rec bool>( + mut a: *const T, + mut b: *const T, + mut c: *const T, + n: usize, + is_less: &mut F, +) -> *const T { + // SAFETY: a, b, c still point to initialized regions of n / 8 elements, + // by the exact same logic as in choose_pivot. + unsafe { + if n * 8 >= PSEUDO_MEDIAN_REC_THRESHOLD { + let n8 = n / 8; + a = median3_rec(a, a.add(n8 * 4), a.add(n8 * 7), n8, is_less); + b = median3_rec(b, b.add(n8 * 4), b.add(n8 * 7), n8, is_less); + c = median3_rec(c, c.add(n8 * 4), c.add(n8 * 7), n8, is_less); + } + median3(&*a, &*b, &*c, is_less) + } +} + +/// Calculates the median of 3 elements. +/// +/// SAFETY: a, b, c must be valid initialized elements. +#[inline(always)] +fn median3 bool>(a: &T, b: &T, c: &T, is_less: &mut F) -> *const T { + // Compiler tends to make this branchless when sensible, and avoids the + // third comparison when not. + let x = is_less(a, b); + let y = is_less(a, c); + if x == y { + // If x=y=0 then b, c <= a. In this case we want to return max(b, c). + // If x=y=1 then a < b, c. In this case we want to return min(b, c). + // By toggling the outcome of b < c using XOR x we get this behavior. + let z = is_less(b, c); + if z ^ x { + c + } else { + b + } + } else { + // Either c <= a < b or b <= a < c, thus a is our median. + a + } +} diff --git a/src/quicksort.rs b/src/quicksort.rs index 52b3eed..ee0cab3 100644 --- a/src/quicksort.rs +++ b/src/quicksort.rs @@ -3,28 +3,24 @@ use core::mem::{self, ManuallyDrop, MaybeUninit}; use core::ptr; use crate::has_direct_interior_mutability; +use crate::pivot::choose_pivot; use crate::smallsort::SmallSortTypeImpl; -// Recursively select a pseudomedian if above this threshold. -const PSEUDO_MEDIAN_REC_THRESHOLD: usize = 64; - /// Sorts `v` recursively using quicksort. /// /// `limit` when initialized with `c*log(v.len())` for some c ensures we do not /// overflow the stack or go quadratic. -pub fn stable_quicksort( +pub fn stable_quicksort bool>( mut v: &mut [T], scratch: &mut [MaybeUninit], mut limit: u32, mut left_ancestor_pivot: Option<&T>, is_less: &mut F, -) where - F: FnMut(&T, &T) -> bool, -{ +) { loop { let len = v.len(); - if len <= T::MAX_LEN_SMALL_SORT { + if len <= T::SMALL_SORT_THRESHOLD { T::sort_small(v, scratch, is_less); return; } @@ -79,112 +75,18 @@ pub fn stable_quicksort( } } -/// Selects a pivot from `v`. Algorithm taken from glidesort by Orson Peters. -/// -/// This chooses a pivot by sampling an adaptive amount of points, approximating -/// the quality of a median of sqrt(n) elements. -fn choose_pivot(v: &[T], is_less: &mut F) -> usize -where - F: FnMut(&T, &T) -> bool, -{ - // We use unsafe code and raw pointers here because we're dealing with - // heavy recursion. Passing safe slices around would involve a lot of - // branches and function call overhead. - - // SAFETY: a, b, c point to initialized regions of len_div_8 elements, - // satisfying median3 and median3_rec's preconditions as v_base points - // to an initialized region of n = len elements. - unsafe { - let v_base = v.as_ptr(); - let len = v.len(); - let len_div_8 = len / 8; - - let a = v_base; // [0, floor(n/8)) - let b = v_base.add(len_div_8 * 4); // [4*floor(n/8), 5*floor(n/8)) - let c = v_base.add(len_div_8 * 7); // [7*floor(n/8), 8*floor(n/8)) - - if len < PSEUDO_MEDIAN_REC_THRESHOLD { - median3(&*a, &*b, &*c, is_less).sub_ptr(v_base) - } else { - median3_rec(a, b, c, len_div_8, is_less).sub_ptr(v_base) - } - } -} - -/// Calculates an approximate median of 3 elements from sections a, b, c, or -/// recursively from an approximation of each, if they're large enough. By -/// dividing the size of each section by 8 when recursing we have logarithmic -/// recursion depth and overall sample from f(n) = 3*f(n/8) -> f(n) = -/// O(n^(log(3)/log(8))) ~= O(n^0.528) elements. -/// -/// SAFETY: a, b, c must point to the start of initialized regions of memory of -/// at least n elements. -unsafe fn median3_rec( - mut a: *const T, - mut b: *const T, - mut c: *const T, - n: usize, - is_less: &mut F, -) -> *const T -where - F: FnMut(&T, &T) -> bool, -{ - // SAFETY: a, b, c still point to initialized regions of n / 8 elements, - // by the exact same logic as in choose_pivot. - unsafe { - if n * 8 >= PSEUDO_MEDIAN_REC_THRESHOLD { - let n8 = n / 8; - a = median3_rec(a, a.add(n8 * 4), a.add(n8 * 7), n8, is_less); - b = median3_rec(b, b.add(n8 * 4), b.add(n8 * 7), n8, is_less); - c = median3_rec(c, c.add(n8 * 4), c.add(n8 * 7), n8, is_less); - } - median3(&*a, &*b, &*c, is_less) - } -} - -/// Calculates the median of 3 elements. -/// -/// SAFETY: a, b, c must be valid initialized elements. -#[inline(always)] -fn median3(a: &T, b: &T, c: &T, is_less: &mut F) -> *const T -where - F: FnMut(&T, &T) -> bool, -{ - // Compiler tends to make this branchless when sensible, and avoids the - // third comparison when not. - let x = is_less(a, b); - let y = is_less(a, c); - if x == y { - // If x=y=0 then b, c <= a. In this case we want to return max(b, c). - // If x=y=1 then a < b, c. In this case we want to return min(b, c). - // By toggling the outcome of b < c using XOR x we get this behavior. - let z = is_less(b, c); - if z ^ x { - c - } else { - b - } - } else { - // Either c <= a < b or b <= a < c, thus a is our median. - a - } -} - /// Partitions `v` using pivot `p = v[pivot_pos]` and returns the number of /// elements less than `p`. The relative order of elements that compare < p and /// those that compare >= p is preserved - it is a stable partition. /// /// If `is_less` is not a strict total order or panics, `scratch.len() < v.len()`, /// or `pivot_pos >= v.len()`, the result and `v`'s state is sound but unspecified. -fn stable_partition( +fn stable_partition bool>( v: &mut [T], scratch: &mut [MaybeUninit], pivot_pos: usize, is_less: &mut F, -) -> usize -where - F: FnMut(&T, &T) -> bool, -{ +) -> usize { let num_lt = T::partition_fill_scratch(v, scratch, pivot_pos, is_less); // SAFETY: partition_fill_scratch guarantees that scratch is initialized @@ -213,27 +115,22 @@ trait StablePartitionTypeImpl: Sized { /// Performs the same operation as [`stable_partition`], except it stores the /// permuted elements as copies in `scratch`, with the >= partition in /// reverse order. - fn partition_fill_scratch( + fn partition_fill_scratch bool>( v: &[Self], scratch: &mut [MaybeUninit], pivot_pos: usize, is_less: &mut F, - ) -> usize - where - F: FnMut(&Self, &Self) -> bool; + ) -> usize; } impl StablePartitionTypeImpl for T { /// See [`StablePartitionTypeImpl::partition_fill_scratch`]. - default fn partition_fill_scratch( + default fn partition_fill_scratch bool>( v: &[T], scratch: &mut [MaybeUninit], pivot_pos: usize, is_less: &mut F, - ) -> usize - where - F: FnMut(&Self, &Self) -> bool, - { + ) -> usize { let len = v.len(); let v_base = v.as_ptr(); let scratch_base = MaybeUninit::slice_as_mut_ptr(scratch); @@ -309,15 +206,12 @@ where (): crate::IsTrue<{ mem::size_of::() <= (mem::size_of::() * 2) }>, { /// See [`StablePartitionTypeImpl::partition_fill_scratch`]. - fn partition_fill_scratch( + fn partition_fill_scratch bool>( v: &[T], scratch: &mut [MaybeUninit], pivot_pos: usize, is_less: &mut F, - ) -> usize - where - F: FnMut(&Self, &Self) -> bool, - { + ) -> usize { let len = v.len(); let v_base = v.as_ptr(); let scratch_base = MaybeUninit::slice_as_mut_ptr(scratch); diff --git a/src/smallsort.rs b/src/smallsort.rs index 42e83e9..66cdd3f 100644 --- a/src/smallsort.rs +++ b/src/smallsort.rs @@ -11,41 +11,45 @@ use core::ptr; // Use a trait to focus code-gen on only the parts actually relevant for the type. Avoid generating // LLVM-IR for the sorting-network and median-networks for types that don't qualify. -pub(crate) trait SmallSortTypeImpl: Sized { - const MAX_LEN_SMALL_SORT: usize; +pub trait SmallSortTypeImpl: Sized { + const SMALL_SORT_THRESHOLD: usize; /// Sorts `v` using strategies optimized for small sizes. - fn sort_small(v: &mut [Self], scratch: &mut [MaybeUninit], is_less: &mut F) - where - F: FnMut(&Self, &Self) -> bool; + fn sort_small bool>( + v: &mut [Self], + scratch: &mut [MaybeUninit], + is_less: &mut F, + ); } impl SmallSortTypeImpl for T { - default const MAX_LEN_SMALL_SORT: usize = 16; + default const SMALL_SORT_THRESHOLD: usize = 16; - default fn sort_small(v: &mut [Self], _scratch: &mut [MaybeUninit], is_less: &mut F) - where - F: FnMut(&T, &T) -> bool, - { + default fn sort_small bool>( + v: &mut [T], + _scratch: &mut [MaybeUninit], + is_less: &mut F, + ) { if v.len() >= 2 { insertion_sort_shift_left(v, 1, is_less); } } } -pub(crate) const MIN_SMALL_SORT_SCRATCH_LEN: usize = i32::MAX_LEN_SMALL_SORT + 16; +pub const MIN_SMALL_SORT_SCRATCH_LEN: usize = i32::SMALL_SORT_THRESHOLD + 16; impl SmallSortTypeImpl for T where T: crate::Freeze, (): crate::IsTrue<{ mem::size_of::() <= 96 }>, { - const MAX_LEN_SMALL_SORT: usize = 20; + const SMALL_SORT_THRESHOLD: usize = 20; - fn sort_small(v: &mut [Self], scratch: &mut [MaybeUninit], is_less: &mut F) - where - F: FnMut(&T, &T) -> bool, - { + fn sort_small bool>( + v: &mut [T], + scratch: &mut [MaybeUninit], + is_less: &mut F, + ) { let len = v.len(); if len >= 2 { @@ -66,14 +70,14 @@ where // SAFETY: scratch_base is valid and has enough space. sort8_stable( v_base, - scratch_base.add(T::MAX_LEN_SMALL_SORT), + scratch_base.add(T::SMALL_SORT_THRESHOLD), scratch_base, is_less, ); sort8_stable( v_base.add(len_div_2), - scratch_base.add(T::MAX_LEN_SMALL_SORT + 8), + scratch_base.add(T::SMALL_SORT_THRESHOLD + 8), scratch_base.add(len_div_2), is_less, ); @@ -161,10 +165,7 @@ impl Drop for GapGuard { /// becomes sorted. Returns the insert position. /// Inserts `v[v.len() - 1]` into pre-sorted sequence `v[..v.len() - 1]` so that whole `v[..]` /// becomes sorted. -unsafe fn insert_tail(v: &mut [T], is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ +unsafe fn insert_tail bool>(v: &mut [T], is_less: &mut F) { if v.len() < 2 { intrinsics::abort(); } @@ -216,10 +217,11 @@ where } /// Sort `v` assuming `v[..offset]` is already sorted. -pub fn insertion_sort_shift_left(v: &mut [T], offset: usize, is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ +pub fn insertion_sort_shift_left bool>( + v: &mut [T], + offset: usize, + is_less: &mut F, +) { let len = v.len(); if offset == 0 || offset > len { @@ -238,10 +240,11 @@ where /// SAFETY: The caller MUST guarantee that `v_base` is valid for 4 reads and `dest_ptr` is valid /// for 4 writes. The result will be stored in `dst[0..4]`. -pub unsafe fn sort4_stable(v_base: *const T, dst: *mut T, is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ +pub unsafe fn sort4_stable bool>( + v_base: *const T, + dst: *mut T, + is_less: &mut F, +) { // By limiting select to picking pointers, we are guaranteed good cmov code-gen regardless of // type T layout. Further this only does 5 instead of 6 comparisons compared to a stable // transposition 4 element sorting-network. Also by only operating on pointers, we get optimal @@ -297,11 +300,12 @@ where /// SAFETY: The caller MUST guarantee that `v_base` is valid for 8 reads and writes, `scratch_base` /// and `dst` MUST be valid for 8 writes. The result will be stored in `dst[0..8]`. #[inline(never)] -unsafe fn sort8_stable(v_base: *mut T, scratch_base: *mut T, dst: *mut T, is_less: &mut F) -where - T: crate::Freeze, - F: FnMut(&T, &T) -> bool, -{ +unsafe fn sort8_stable bool>( + v_base: *mut T, + scratch_base: *mut T, + dst: *mut T, + is_less: &mut F, +) { // SAFETY: The caller must guarantee that scratch_base is valid for 8 writes, and that v_base is // valid for 8 reads. unsafe { @@ -316,15 +320,12 @@ where } #[inline(always)] -unsafe fn merge_up( +unsafe fn merge_up bool>( mut left_src: *const T, mut right_src: *const T, mut dst: *mut T, is_less: &mut F, -) -> (*const T, *const T, *mut T) -where - F: FnMut(&T, &T) -> bool, -{ +) -> (*const T, *const T, *mut T) { // This is a branchless merge utility function. // The equivalent code with a branch would be: // @@ -352,15 +353,12 @@ where } #[inline(always)] -unsafe fn merge_down( +unsafe fn merge_down bool>( mut left_src: *const T, mut right_src: *const T, mut dst: *mut T, is_less: &mut F, -) -> (*const T, *const T, *mut T) -where - F: FnMut(&T, &T) -> bool, -{ +) -> (*const T, *const T, *mut T) { // This is a branchless merge utility function. // The equivalent code with a branch would be: // @@ -395,11 +393,11 @@ where /// // SAFETY: the caller must guarantee that `dst` is valid for v.len() writes. // Also `v.as_ptr` and `dst` must not alias. -unsafe fn bi_directional_merge_even(v: &[T], dst: *mut T, is_less: &mut F) -where - T: crate::Freeze, - F: FnMut(&T, &T) -> bool, -{ +unsafe fn bi_directional_merge_even bool>( + v: &[T], + dst: *mut T, + is_less: &mut F, +) { // The caller must guarantee that T cannot modify itself inside is_less. // merge_up and merge_down read left and right pointers and potentially modify the stack value // they point to, if T has interior mutability. This may leave one or two potential writes to