diff --git a/src/drift.rs b/src/drift.rs
index 3c27783..e19afdf 100644
--- a/src/drift.rs
+++ b/src/drift.rs
@@ -11,15 +11,12 @@ fn logical_merge<T, F: FnMut(&T, &T) -> bool>(
     right: DriftsortRun,
     is_less: &mut F,
 ) -> DriftsortRun {
-    // If one or both of the runs are sorted do a physical merge. Using quicksort to sort the
-    // unsorted run if present.
-
+    // If one or both of the runs are sorted do a physical merge, using
+    // quicksort to sort the unsorted run if present. We also *need* to
+    // physically merge if the combined runs would not fit in the scratch space
+    // anymore (as this would mean we are no longer able to quicksort them).
     let len = v.len();
-
-    // We *need* to physically merge if the combined runs do not fit in the scratch space anymore
-    // (as this would mean we are no longer able to to quicksort them).
     let can_fit_in_scratch = len <= scratch.len();
-
     if !can_fit_in_scratch || left.sorted() || right.sorted() {
         if !left.sorted() {
             crate::stable_quicksort(&mut v[..left.len()], scratch, is_less);
@@ -27,7 +24,6 @@ fn logical_merge<T, F: FnMut(&T, &T) -> bool>(
         if !right.sorted() {
             crate::stable_quicksort(&mut v[left.len()..], scratch, is_less);
         }
-
         crate::physical_merge(v, scratch, left.len(), is_less);

         DriftsortRun::new_sorted(len)
@@ -109,18 +105,22 @@ pub fn sort<T, F: FnMut(&T, &T) -> bool>(
     eager_sort: bool,
     is_less: &mut F,
 ) {
-    // What's the smallest possible sub-slice that is considered a already sorted run and used for
-    // merging.
-    const MIN_MERGE_SLICE_LEN: usize = 32;
-
     let len = v.len();
     if len < 2 {
         return; // Removing this length check *increases* code size.
     }
-
     let scale_factor = merge_tree_scale_factor(len);

-    let min_good_run_len = if len <= (MIN_MERGE_SLICE_LEN * MIN_MERGE_SLICE_LEN) {
+    // It's important to have a relatively high entry barrier for pre-sorted
+    // runs, as the presence of a single such run will force on average several
+    // merge operations and shrink the maximum quicksort size a lot. For that
+    // reason we use sqrt(len) as our pre-sorted run threshold, but no smaller
+    // than 32. When eagerly sorting we use crate::quicksort::SMALL_SORT_THRESHOLD
+    // as our threshold, as we will call small_sort on any runs smaller than this.
+    const MIN_MERGE_SLICE_LEN: usize = 32;
+    let min_good_run_len = if eager_sort {
+        crate::quicksort::SMALL_SORT_THRESHOLD
+    } else if len <= (MIN_MERGE_SLICE_LEN * MIN_MERGE_SLICE_LEN) {
         MIN_MERGE_SLICE_LEN
     } else {
         sqrt_approx(len)
diff --git a/src/lib.rs b/src/lib.rs
index a19e1ae..faf582c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -22,8 +22,6 @@ mod merge;
 mod quicksort;
 mod smallsort;

-const FALLBACK_RUN_LEN: usize = 10;
-
 /// Compactly stores the length of a run, and whether or not it is sorted. This
 /// can always fit in a usize because the maximum slice length is isize::MAX.
 #[derive(Copy, Clone)]
@@ -72,36 +70,27 @@ where
     F: FnMut(&T, &T) -> bool,
     BufT: BufGuard<T>,
 {
-    // Sorting has no meaningful behavior on zero-sized types.
+    // Arrays of zero-sized types are always all-equal, and thus sorted.
     if T::IS_ZST {
         return;
     }

+    // Instrumenting the standard library showed that 90+% of the calls to sort
+    // by rustc are either of size 0 or 1.
     let len = v.len();
-
-    // This path is critical for very small inputs. Always pick insertion sort for these inputs,
-    // without any other analysis. This is perf critical for small inputs, in cold code.
-    const MAX_LEN_ALWAYS_INSERTION_SORT: usize = 20;
-
-    // Instrumenting the standard library showed that 90+% of the calls to sort by rustc are either
-    // of size 0 or 1. Make this path extra fast by assuming the branch is likely.
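// [Illustrative sketch, not part of the patch.] The run threshold selected in
// the drift.rs hunk above works out to max(32, ~sqrt(len)) when not eagerly
// sorting, and the small-sort threshold otherwise. A standalone version, with
// an assumed SMALL_SORT_THRESHOLD value and a plain Heron-step integer sqrt
// standing in for the crate's `sqrt_approx` (whose body is not shown here):
fn min_good_run_len_sketch(len: usize, eager_sort: bool) -> usize {
    const SMALL_SORT_THRESHOLD: usize = 20; // assumed value, for illustration only
    const MIN_MERGE_SLICE_LEN: usize = 32;

    fn isqrt(n: usize) -> usize {
        // Heron's method; converges to floor(sqrt(n)).
        let mut x = n;
        let mut y = (x + 1) / 2;
        while y < x {
            x = y;
            y = (x + n / x) / 2;
        }
        x
    }

    if eager_sort {
        SMALL_SORT_THRESHOLD
    } else if len <= MIN_MERGE_SLICE_LEN * MIN_MERGE_SLICE_LEN {
        MIN_MERGE_SLICE_LEN
    } else {
        isqrt(len) // e.g. len = 1_000_000 gives a threshold of 1000
    }
}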
if intrinsics::likely(len < 2) { return; } - // It's important to differentiate between small-sort performance for small slices and - // small-sort performance sorting small sub-slices as part of the main quicksort loop. For the - // former, testing showed that the representative benchmarks for real-world performance are cold - // CPU state and not single-size hot benchmarks. For the latter the CPU will call them many - // times, so hot benchmarks are fine and more realistic. And it's worth it to optimize sorting - // small sub-slices with more sophisticated solutions than insertion sort. - + // More advanced sorting methods than insertion sort are faster if called in + // a hot loop for small inputs, but for general-purpose code the small + // binary size of insertion sort is more important. The instruction cache in + // modern processors is very valuable, and for a single sort call in general + // purpose code any gains from an advanced method are cancelled by icache + // misses during the sort, and thrashing the icache for surrounding code. + const MAX_LEN_ALWAYS_INSERTION_SORT: usize = 20; if intrinsics::likely(len <= MAX_LEN_ALWAYS_INSERTION_SORT) { - // More specialized and faster options, extending the range of allocation free sorting - // are possible but come at a great cost of additional code, which is problematic for - // compile-times. smallsort::insertion_sort_shift_left(v, 1, is_less); - return; } @@ -114,27 +103,18 @@ where F: FnMut(&T, &T) -> bool, BufT: BufGuard, { - // Allocating len instead of len / 2 allows the quicksort to work on the full size, which can - // give speedups especially for low cardinality inputs where common values are filtered out only - // once, instead of twice. And it allows bi-directional merging the full input. However to - // reduce peak memory usage for large inputs, fall back to allocating len / 2 if a certain - // threshold is passed. - const MAX_FULL_ALLOC_BYTES: usize = 8_000_000; // 8MB - - let len = v.len(); - // Pick whichever is greater: - // - // - alloc n up to MAX_FULL_ALLOC_BYTES - // - alloc n / 2 - // - // This serves to make the impact and performance cliff when going above the threshold less - // severe than immediately switching to len / 2. + // - alloc len elements up to MAX_FULL_ALLOC_BYTES + // - alloc len / 2 elements + // This allows us to use the most performant algorithms for small-medium + // sized inputs while scaling down to len / 2 for larger inputs. We need at + // least len / 2 for our stable merging routine. + const MAX_FULL_ALLOC_BYTES: usize = 8_000_000; + let len = v.len(); let full_alloc_size = cmp::min(len, MAX_FULL_ALLOC_BYTES / mem::size_of::()); let alloc_size = cmp::max(len / 2, full_alloc_size); let mut buf = BufT::with_capacity(alloc_size); - let scratch_slice = unsafe { slice::from_raw_parts_mut(buf.mut_ptr() as *mut MaybeUninit, buf.capacity()) }; @@ -163,84 +143,65 @@ fn stable_quicksort bool>( crate::quicksort::stable_quicksort(v, scratch, limit, None, is_less); } -/// Create a new logical run, that is either sorted or unsorted. +/// Creates a new logical run. +/// +/// A logical run can either be sorted or unsorted. If there is a pre-existing +/// run of length min_good_run_len (or longer) starting at v[0] we find and +/// return it, otherwise we return a run of length min_good_run_len that is +/// eagerly sorted when eager_sort is true, and left unsorted otherwise. 
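// [Illustrative sketch, not part of the patch.] The scratch sizing in
// `driftsort_main` above picks the larger of "len / 2" and "len capped to an
// 8 MB budget", so small and medium inputs get a full-length buffer while very
// large inputs fall back to the len / 2 minimum the stable merge needs:
fn scratch_len_sketch<T>(len: usize) -> usize {
    const MAX_FULL_ALLOC_BYTES: usize = 8_000_000; // 8 MB, as in the patch
    // max(1) only guards the zero-sized-type case, which the real entry point
    // rejects before ever allocating.
    let full_alloc_size = core::cmp::min(len, MAX_FULL_ALLOC_BYTES / core::mem::size_of::<T>().max(1));
    core::cmp::max(len / 2, full_alloc_size)
}
// For T = u64: scratch_len_sketch::<u64>(100_000) == 100_000, while
// scratch_len_sketch::<u64>(10_000_000) == 5_000_000.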
 fn create_run<T, F: FnMut(&T, &T) -> bool>(
     v: &mut [T],
-    mut min_good_run_len: usize,
+    min_good_run_len: usize,
     eager_sort: bool,
     is_less: &mut F,
 ) -> DriftsortRun {
-    // FIXME: run detection.
-
-    let len = v.len();
-
-    let (streak_end, was_reversed) = find_streak(v, is_less);
-
-    if eager_sort {
-        min_good_run_len = FALLBACK_RUN_LEN;
-    }
-
-    // It's important to have a relatively high entry barrier for pre-sorted runs, as the presence
-    // of a single such run will force on average several merge operations and shrink the max
-    // quicksort size a lot. Which impact low-cardinality filtering performance.
-    if streak_end >= min_good_run_len {
+    let (run_len, was_reversed) = find_existing_run(v, is_less);
+    if run_len >= min_good_run_len {
         if was_reversed {
-            v[..streak_end].reverse();
+            v[..run_len].reverse();
         }
-
-        DriftsortRun::new_sorted(streak_end)
+        DriftsortRun::new_sorted(run_len)
     } else {
-        if !eager_sort {
-            // min_good_run_len serves dual duty here, if no streak was found, create a relatively
-            // large unsorted run to avoid calling find_streak all the time. This also puts a limit
-            // on how many logical merges have to be done, but this plays a minor role performance
-            // wise.
-            DriftsortRun::new_unsorted(cmp::min(min_good_run_len, len))
+        let new_run_len = cmp::min(min_good_run_len, v.len());
+        if eager_sort {
+            smallsort::sort_small(&mut v[..new_run_len], is_less);
+            DriftsortRun::new_sorted(new_run_len)
         } else {
-            // We are not allowed to generate unsorted sequences in this mode. This mode is used as
-            // fallback algorithm for quicksort. Essentially falling back to merge sort.
-            let run_end = cmp::min(crate::quicksort::SMALL_SORT_THRESHOLD, len);
-            smallsort::sort_small(&mut v[..run_end], is_less);
-
-            DriftsortRun::new_sorted(run_end)
+            DriftsortRun::new_unsorted(new_run_len)
         }
     }
 }

-/// Finds a streak of presorted elements starting at the beginning of the slice. Returns the first
-/// value that is not part of said streak, and a bool denoting wether the streak was reversed.
-/// Streaks can be increasing or decreasing.
-fn find_streak<T, F>(v: &[T], is_less: &mut F) -> (usize, bool)
+/// Finds a run of sorted elements starting at the beginning of the slice.
+///
+/// Returns the length of the run, and a bool that is false when the run
+/// is ascending, and true if the run is strictly descending.
+fn find_existing_run<T, F>(v: &[T], is_less: &mut F) -> (usize, bool)
 where
     F: FnMut(&T, &T) -> bool,
 {
     let len = v.len();
-
     if len < 2 {
         return (len, false);
     }

-    let mut end = 2;
-
-    // SAFETY: See below specific.
     unsafe {
         // SAFETY: We checked that len >= 2, so 0 and 1 are valid indices.
-        let assume_reverse = is_less(v.get_unchecked(1), v.get_unchecked(0));
-
-        // SAFETY: We know end >= 2 and check end < len.
-        // From that follows that accessing v at end and end - 1 is safe.
-        if assume_reverse {
-            while end < len && is_less(v.get_unchecked(end), v.get_unchecked(end - 1)) {
-                end += 1;
+        // This also means that run_len < len implies run_len and
+        // run_len - 1 are valid indices as well.
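// [Illustrative sketch, not part of the patch.] A bounds-checked version of
// the run detection documented above, handy for seeing the intended semantics
// without the raw-pointer code that follows in this hunk:
fn find_existing_run_safe<T, F: FnMut(&T, &T) -> bool>(v: &[T], is_less: &mut F) -> (usize, bool) {
    let len = v.len();
    if len < 2 {
        return (len, false);
    }
    let strictly_descending = is_less(&v[1], &v[0]);
    let mut run_len = 2;
    while run_len < len {
        let extends_run = if strictly_descending {
            is_less(&v[run_len], &v[run_len - 1])
        } else {
            !is_less(&v[run_len], &v[run_len - 1])
        };
        if !extends_run {
            break;
        }
        run_len += 1;
    }
    (run_len, strictly_descending)
}
// find_existing_run_safe(&[3, 2, 1, 9], &mut |a, b| a < b) == (3, true)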
+ let mut run_len = 2; + let strictly_descending = is_less(v.get_unchecked(1), v.get_unchecked(0)); + if strictly_descending { + while run_len < len && is_less(v.get_unchecked(run_len), v.get_unchecked(run_len - 1)) { + run_len += 1; } - - (end, true) } else { - while end < len && !is_less(v.get_unchecked(end), v.get_unchecked(end - 1)) { - end += 1; + while run_len < len && !is_less(v.get_unchecked(run_len), v.get_unchecked(run_len - 1)) + { + run_len += 1; } - (end, false) } + (run_len, strictly_descending) } } diff --git a/src/quicksort.rs b/src/quicksort.rs index 8edbb40..4c5e7f3 100644 --- a/src/quicksort.rs +++ b/src/quicksort.rs @@ -12,23 +12,17 @@ const PSEUDO_MEDIAN_REC_THRESHOLD: usize = 64; /// Sorts `v` recursively using quicksort. /// -/// `limit` ensures we do not stack overflow and do not go quadratic. If reached -/// we switch to purely mergesort by eager sorting. +/// `limit` when initialized with `c*log(v.len())` for some c ensures we do not +/// overflow the stack or go quadratic. pub fn stable_quicksort( mut v: &mut [T], scratch: &mut [MaybeUninit], mut limit: u32, - mut ancestor_pivot: Option<&T>, + mut left_ancestor_pivot: Option<&T>, is_less: &mut F, ) where F: FnMut(&T, &T) -> bool, { - // To improve filtering out of common values with equal partition, we remember the - // ancestor_pivot and use that to compare it to the next pivot selection. Because we can't move - // the relative position of the pivot in a stable sort and subsequent partitioning may change - // the position, its easier to simply make a copy of the pivot value and use that for further - // comparisons. - loop { if v.len() <= SMALL_SORT_THRESHOLD { crate::smallsort::sort_small(v, is_less); @@ -41,54 +35,42 @@ pub fn stable_quicksort( } limit -= 1; - let pivot = choose_pivot(v, is_less); - - let mut should_do_equal_partition = false; - - // If the chosen pivot is equal to the ancestor_pivot, then it's the smallest element in the - // slice. Partition the slice into elements equal to and elements greater than the pivot. - // This case is usually hit when the slice contains many duplicate elements. - if let Some(a_pivot) = ancestor_pivot { - should_do_equal_partition = !is_less(a_pivot, &v[pivot]); + // SAFETY: We only access the temporary copy for Freeze types, otherwise + // self-modifications via `is_less` would not be observed and this would + // be unsound. Our temporary copy does not escape this scope. + let pivot_idx = choose_pivot(v, is_less); + let pivot_copy = unsafe { ManuallyDrop::new(ptr::read(&v[pivot_idx])) }; + let pivot_ref = (!has_direct_interior_mutability::()).then_some(&*pivot_copy); + + // We choose a pivot, and check if this pivot is equal to our left + // ancestor. If true, we do a partition putting equal elements on the + // left and do not recurse on it. This gives O(n log k) sorting for k + // distinct values, a strategy borrowed from pdqsort. For types with + // interior mutability we can't soundly create a temporary copy of the + // ancestor pivot, and use left_partition_len == 0 as our method for + // detecting when we re-use a pivot, which means we do at most three + // partition operations with pivot p instead of the optimal two. + let mut perform_equal_partition = false; + if let Some(la_pivot) = left_ancestor_pivot { + perform_equal_partition = !is_less(la_pivot, &v[pivot_idx]); } - // SAFETY: See we only use this value for Feeze types, otherwise self-modifications via - // `is_less` would not be observed and this would be unsound. 
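// [Illustrative sketch, not part of the patch.] The caller's exact choice of
// `limit` is not visible in this diff; a typical "c * log2(len)" style limit
// could be computed like this (the constant 2 here is an assumption, not the
// crate's actual value):
fn quicksort_depth_limit_sketch(len: usize) -> u32 {
    2 * len.max(2).ilog2()
}
// quicksort_depth_limit_sketch(1 << 20) == 40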
- // - // It's important we do this after we picked the pivot and checked it against the - // ancestor_pivot, but before we change v again by partitioning. - let pivot_copy = unsafe { ManuallyDrop::new(ptr::read(&v[pivot])) }; - - let mut mid = 0; - - if !should_do_equal_partition { - mid = stable_partition(v, scratch, pivot, is_less); - - // Fallback for non Freeze types. - should_do_equal_partition = mid == 0; + let mut left_partition_len = 0; + if !perform_equal_partition { + left_partition_len = stable_partition(v, scratch, pivot_idx, is_less); + perform_equal_partition = left_partition_len == 0; } - if should_do_equal_partition { - let mid_eq = stable_partition(v, scratch, pivot, &mut |a, b| !is_less(b, a)); + if perform_equal_partition { + let mid_eq = stable_partition(v, scratch, pivot_idx, &mut |a, b| !is_less(b, a)); v = &mut v[mid_eq..]; - ancestor_pivot = None; + left_ancestor_pivot = None; continue; } - let (left, right) = v.split_at_mut(mid); - - let new_ancestor_pivot = if const { !has_direct_interior_mutability::() } { - // SAFETY: pivot_copy is valid and initialized, lives on the stack and as a consequence - // outlives the immediate call to stable_quicksort. - unsafe { Some(&*(&pivot_copy as *const ManuallyDrop as *const T)) } - } else { - None - }; - - // Processing right side with recursion. - stable_quicksort(right, scratch, limit, new_ancestor_pivot, is_less); - - // Processing left side with next loop iteration. + // Process left side with the next loop iter, right side with recursion. + let (left, right) = v.split_at_mut(left_partition_len); + stable_quicksort(right, scratch, limit, pivot_ref, is_less); v = left; } } @@ -101,21 +83,20 @@ fn choose_pivot(v: &[T], is_less: &mut F) -> usize where F: FnMut(&T, &T) -> bool, { - let len = v.len(); + // We use unsafe code and raw pointers here because we're dealing with + // heavy recursion. Passing safe slices around would involve a lot of + // branches and function call overhead. - // SAFETY: The pointer operations are guaranteed to be in-bounds no matter the len of `v`. From - // which follows the calls to median3 and median3_rec are provided with pointers to valid - // elements and thus safe. + // SAFETY: a, b, c point to initialized regions of len_div_8 elements, + // satisfying median3 and median3_rec's preconditions as arr_ptr points + // to an initialized region of n = len elements. unsafe { - // We use unsafe code and raw pointers here because we're dealing with - // heavy recursion. Passing safe slices around would involve a lot of - // branches and function call overhead. let arr_ptr = v.as_ptr(); - + let len = v.len(); let len_div_8 = len / 8; - let a = arr_ptr; - let b = arr_ptr.add(len_div_8 * 4); - let c = arr_ptr.add(len_div_8 * 7); + let a = arr_ptr; // [0, floor(n/8)) + let b = arr_ptr.add(len_div_8 * 4); // [4*floor(n/8), 5*floor(n/8)) + let c = arr_ptr.add(len_div_8 * 7); // [7*floor(n/8), 8*floor(n/8)) if len < PSEUDO_MEDIAN_REC_THRESHOLD { median3(a, b, c, is_less).sub_ptr(arr_ptr) @@ -125,10 +106,11 @@ where } } -/// Calculates an approximate median of 3 elements from sections a, b, c, or recursively from an -/// approximation of each, if they're large enough. By dividing the size of each section by 8 when -/// recursing we have logarithmic recursion depth and overall sample from -/// f(n) = 3*f(n/8) -> f(n) = O(n^(log(3)/log(8))) ~= O(n^0.528) elements. 
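// [Illustrative sketch, not part of the patch.] The pivot selection above
// reduces three spread-out samples with a median-of-3 at each level. A safe,
// index-based median-of-3 with the same 2-3 comparison structure as the
// pointer-based `median3` used here:
fn median3_idx<T, F: FnMut(&T, &T) -> bool>(
    v: &[T],
    a: usize,
    b: usize,
    c: usize,
    is_less: &mut F,
) -> usize {
    let x = is_less(&v[a], &v[b]);
    let y = is_less(&v[a], &v[c]);
    if x == y {
        // a is the smallest or the largest of the three, so the median is
        // min(b, c) or max(b, c) respectively; toggling on x picks it.
        let z = is_less(&v[b], &v[c]);
        if z == x { b } else { c }
    } else {
        // Exactly one of b, c compares below a, so a is the median.
        a
    }
}
// median3_idx(&[7, 1, 4], 0, 1, 2, &mut |a, b| a < b) == 2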
+/// Calculates an approximate median of 3 elements from sections a, b, c, or +/// recursively from an approximation of each, if they're large enough. By +/// dividing the size of each section by 8 when recursing we have logarithmic +/// recursion depth and overall sample from f(n) = 3*f(n/8) -> f(n) = +/// O(n^(log(3)/log(8))) ~= O(n^0.528) elements. /// /// SAFETY: a, b, c must point to the start of initialized regions of memory of /// at least n elements. @@ -143,7 +125,8 @@ unsafe fn median3_rec( where F: FnMut(&T, &T) -> bool, { - // SAFETY: See function safety description. + // SAFETY: a, b, c still point to initialized regions of n / 8 elements, + // by the exact same logic as in choose_pivot. unsafe { if n * 8 >= PSEUDO_MEDIAN_REC_THRESHOLD { let n8 = n / 8; @@ -173,7 +156,6 @@ where // If x=y=1 then a < b, c. In this case we want to return min(b, c). // By toggling the outcome of b < c using XOR x we get this behavior. let z = is_less(&*b, &*c); - if z ^ x { c } else { @@ -186,17 +168,12 @@ where } } -/// Takes the input slice `v` and re-arranges elements such that when the call returns normally -/// all elements that compare true for `is_less(elem, pivot)` where `pivot == v[pivot_pos]` are -/// on the left side of `v` followed by the other elements, notionally considered greater or -/// equal to `pivot`. +/// Partitions `v` using pivot `p = v[pivot_pos]` and returns the number of +/// elements less than `p`. The relative order of elements that compare < p and +/// those that compare >= p is preserved - it is a stable partition. /// -/// Returns the number of elements that are compared true for `is_less(elem, pivot)`. -/// -/// If `is_less` does not implement a total order the resulting order and return value are -/// unspecified. All original elements will remain in `v` and any possible modifications via -/// interior mutability will be observable. Same is true if `is_less` panics or `v.len()` -/// exceeds `scratch.len()`. +/// If `is_less` is not a strict total order or panics, `scratch.len() < v.len()`, +/// or `pivot_pos >= v.len()`, the result and `v`'s state is sound but unspecified. fn stable_partition( v: &mut [T], scratch: &mut [MaybeUninit], @@ -206,54 +183,38 @@ fn stable_partition( where F: FnMut(&T, &T) -> bool, { - let len = v.len(); - let arr_ptr = v.as_mut_ptr(); + let num_lt = T::partition_fill_scratch(v, scratch, pivot_pos, is_less); - if intrinsics::unlikely(scratch.len() < len || pivot_pos >= len) { - debug_assert!(false); // That's a logic bug in the implementation. - return 0; - } - - let scratch_ptr = MaybeUninit::slice_as_mut_ptr(scratch); - - // SAFETY: We checked that `pivot_pos` is in-bounds and that `scratch` is valid for `len` - // writes, fulfilling the safety contract of partition_fill_scratch. Assuming - // partition_fill_scratch works as documented `scratch` should hold valid elements that observed - // all possible changes to them, and can then be copied back into `v`. + // SAFETY: partition_fill_scratch guarantees that scratch is initialized + // with a permuted copy of `v`, and that num_lt <= v.len(). Copying + // scratch[0..num_lt] and scratch[num_lt..v.len()] back is thus + // sound, as the values in scratch will never be read again, meaning our + // copies semantically act as moves, permuting `v`. unsafe { - // We can just use the value inside the slice and avoid a drop guard around a stack copy - // of the value, because we only write into scratch during the scan loop. This - // simplifies the code and shows no perf difference. 
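// [Illustrative sketch, not part of the patch.] The contract of the stable
// partition above, written as a plain allocating version (no scratch reuse,
// no unsafe): elements comparing less than the pivot come first, both groups
// keep their relative order, and the count of "less" elements is returned.
// Clone keeps the sketch trivially sound; the real code moves elements through
// the scratch buffer instead and must handle interior mutability explicitly.
fn stable_partition_simple<T: Clone, F: FnMut(&T, &T) -> bool>(
    v: &mut [T],
    pivot_pos: usize,
    is_less: &mut F,
) -> usize {
    let pivot = v[pivot_pos].clone();
    let mut lt = Vec::new();
    let mut ge = Vec::new();
    for x in v.iter() {
        if is_less(x, &pivot) {
            lt.push(x.clone());
        } else {
            ge.push(x.clone());
        }
    }
    let num_lt = lt.len();
    for (dst, src) in v.iter_mut().zip(lt.into_iter().chain(ge)) {
        *dst = src;
    }
    num_lt
}
// With v = [0, 5, 1, 6, 2, 7, 3, 8] and pivot_pos = 1 (pivot 5) this yields
// [0, 1, 2, 3, 5, 6, 7, 8] and returns 4, matching the old doc example.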
-        let pivot_ptr = arr_ptr.add(pivot_pos);
+        let len = v.len();
+        let arr_ptr = v.as_mut_ptr();
+        let scratch_ptr = MaybeUninit::slice_as_mut_ptr(scratch);
-        let lt_count = T::partition_fill_scratch(arr_ptr, len, scratch_ptr, pivot_ptr, is_less);
+        // Copy all the elements < p directly from scratch to v.
+        ptr::copy_nonoverlapping(scratch_ptr, arr_ptr, num_lt);
-        // Copy all the elements that were not equal directly from swap to v.
-        ptr::copy_nonoverlapping(scratch_ptr, arr_ptr, lt_count);
-
-        // Copy the elements that were equal or more from the buf into v and reverse them.
-        let rev_buf_ptr = scratch_ptr.add(len - 1);
-        for i in 0..len - lt_count {
-            ptr::copy_nonoverlapping(rev_buf_ptr.sub(i), arr_ptr.add(lt_count + i), 1);
+        // Copy the elements >= p in reverse order.
+        for i in 0..len - num_lt {
+            ptr::copy_nonoverlapping(scratch_ptr.add(len - 1 - i), arr_ptr.add(num_lt + i), 1);
         }
-        lt_count
+        num_lt
     }
 }

 trait StablePartitionTypeImpl: Sized {
-    /// Takes a slice of `len` pointed to by `arr_ptr` and fills `scratch_ptr` with a partitioned
-    /// copy of the values according to `is_less`.
-    ///
-    /// Example [05162738] -> [01238765]
-    ///
-    /// SAFETY: The caller MUST ensure that `arr_ptr` points to a valid slice of `len` elements and
-    /// that `scratch_ptr` is valid for `len` writes.
-    unsafe fn partition_fill_scratch<F>(
-        arr_ptr: *mut Self,
-        len: usize,
-        scratch_ptr: *mut Self,
-        pivot_ptr: *const Self,
+    /// Performs the same operation as [`stable_partition`], except it stores the
+    /// permuted elements as copies in `scratch`, with the >= partition in
+    /// reverse order.
+    fn partition_fill_scratch<F>(
+        v: &[Self],
+        scratch: &mut [MaybeUninit<Self>],
+        pivot_pos: usize,
         is_less: &mut F,
     ) -> usize
     where
@@ -262,142 +223,167 @@ trait StablePartitionTypeImpl: Sized {
 impl<T> StablePartitionTypeImpl for T {
     /// See [`StablePartitionTypeImpl::partition_fill_scratch`].
-    default unsafe fn partition_fill_scratch<F>(
-        arr_ptr: *mut Self,
-        len: usize,
-        scratch_ptr: *mut Self,
-        pivot_ptr: *const Self,
+    default fn partition_fill_scratch<F>(
+        v: &[T],
+        scratch: &mut [MaybeUninit<T>],
+        pivot_pos: usize,
         is_less: &mut F,
     ) -> usize
     where
         F: FnMut(&Self, &Self) -> bool,
     {
-        // We need to take special care of types with interior mutability. `is_less` can modify the
-        // values it is provided. For example if `pivot_ptr` points to an element in the middle. A
-        // copy of the backing element would be written into the scratch space, and later
-        // modifications to the element behind `pivot_ptr` would be missed by subsequent calls to
-        // `is_less(&*elem_ptr, &*pivot_ptr)`. This can quickly lead to UB, e.g.
-        // `Mutex<Option<Box<_>>>` could miss an update where the `Option` is set to `None`
-        // which would cause a double free.
-
-        // SAFETY: The element access is arr_ptr + i, where i < len, which makes it proven
-        // in-bounds, assuming the caller upholds the function safety contract. The two output
-        // pointers `scratch_ptr` and `ge_out_ptr` each point to a unique location within the range
-        // of `scratch_ptr`, and the combination of always doing decrementing `ge_out_ptr` and
-        // conditionally incrementing `lt_count` ensures that every location of `scratch_ptr` will
-        // be written. If `is_less` panics, only un-observed copies were written into the scratch
-        // space.
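// [Illustrative sketch, not part of the patch.] Why the reversed copy-back
// above restores a stable order: the scan leaves scratch as the "< p" block in
// original order followed by the ">= p" block written back-to-front, so
// copying the first block forward and the second block backward puts both
// groups back in their original relative order. With plain integers:
fn copy_back_demo() {
    let scratch = [0, 1, 2, 3, 8, 7, 6, 5]; // as left by the scan, pivot = 5
    let num_lt = 4;
    let len = scratch.len();

    let mut v = [0; 8];
    v[..num_lt].copy_from_slice(&scratch[..num_lt]);
    for i in 0..len - num_lt {
        v[num_lt + i] = scratch[len - 1 - i];
    }
    assert_eq!(v, [0, 1, 2, 3, 5, 6, 7, 8]);
}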
- unsafe { - let mut pivot_out_ptr = ptr::null_mut(); + let len = v.len(); + let arr_ptr = v.as_ptr(); + let scratch_ptr = MaybeUninit::slice_as_mut_ptr(scratch); - // lt == less than, ge == greater or equal - let mut lt_count = 0; - let mut ge_out_ptr = scratch_ptr.add(len); + if intrinsics::unlikely(scratch.len() < len || pivot_pos >= len) { + core::intrinsics::abort() + } + unsafe { + // Abbreviations: lt == less than, ge == greater or equal. + // + // SAFETY: we checked that pivot_pos is in-bounds above, and that + // scratch has length at least len. As we do binary classification + // into lt or ge, the invariant num_lt + num_ge = i always holds at + // the start of each iteration. For micro-optimization reasons we + // write i - num_lt instead of num_gt. Since num_lt increases by at + // most 1 each iteration and since i < len, this proves our + // destination indices num_lt and len - 1 - num_ge stay + // in-bounds, and are never equal except the final iteration when + // num_lt = len - 1 - (len - 1 - num_lt) = len - 1 - num_ge. + // We write one different element to scratch each iteration thus + // scratch[0..len] will be initialized with a permutation of v. + // + // Should a panic occur, the copies in the scratch space are simply + // forgotten - even with interior mutability all data is still in v. + let pivot = arr_ptr.add(pivot_pos); + let mut pivot_in_scratch = ptr::null_mut(); + let mut num_lt = 0; + let mut scratch_rev = scratch_ptr.add(len); for i in 0..len { - let elem_ptr = arr_ptr.add(i); - ge_out_ptr = ge_out_ptr.sub(1); + let scan = arr_ptr.add(i); + scratch_rev = scratch_rev.sub(1); - let is_less_than_pivot = is_less(&*elem_ptr, &*pivot_ptr); - - let dst_ptr_base = if is_less_than_pivot { - scratch_ptr + let is_less_than_pivot = is_less(&*scan, &*pivot); + let dst = if is_less_than_pivot { + scratch_ptr.add(num_lt) // i + num_lt } else { - ge_out_ptr + scratch_rev.add(num_lt) // len - (i + 1) + num_lt = len - 1 - num_ge }; - let dst_ptr = dst_ptr_base.add(lt_count); - - ptr::copy_nonoverlapping(elem_ptr, dst_ptr, 1); + // Save pivot location in scratch for later. if const { crate::has_direct_interior_mutability::() } - && intrinsics::unlikely(elem_ptr as *const T == pivot_ptr) + && intrinsics::unlikely(scan as *const T == pivot) { - pivot_out_ptr = dst_ptr; + pivot_in_scratch = dst; } - lt_count += is_less_than_pivot as usize; + ptr::copy_nonoverlapping(scan, dst, 1); + num_lt += is_less_than_pivot as usize; } + // SAFETY: if T has interior mutability our copy in scratch can be + // outdated, update it. if const { crate::has_direct_interior_mutability::() } { - ptr::copy_nonoverlapping(pivot_ptr, pivot_out_ptr, 1); + ptr::copy_nonoverlapping(pivot, pivot_in_scratch, 1); } - lt_count + num_lt } } } -/// Specialization for int like types. +/// Specialization for small types, through traits to not invoke compile time +/// penalties for loop unrolling when not used. +/// +/// Benchmarks show that for small types simply storing to *both* potential +/// destinations is more efficient than a conditional store. It is also less at +/// risk of having the compiler generating a branch instead of conditional +/// store. And explicit loop unrolling is also often very beneficial. impl StablePartitionTypeImpl for T where - T: crate::Freeze + Copy, + T: Copy, (): crate::IsTrue<{ mem::size_of::() <= (mem::size_of::() * 2) }>, { /// See [`StablePartitionTypeImpl::partition_fill_scratch`]. 
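// [Illustrative sketch, not part of the patch.] The "store to both candidate
// destinations, advance only the chosen cursor" idea from the specialization
// above, shown on plain integers with safe indexing. Each iteration writes the
// element to the next front slot and to the mirrored back slot; only num_lt
// advances conditionally, so the copy that was not selected is overwritten by
// a later iteration (or coincides with the selected slot on the last one).
fn branchless_partition_demo(v: &[i32], pivot: i32) -> (Vec<i32>, usize) {
    let len = v.len();
    let mut scratch = vec![0; len];
    let mut num_lt = 0;
    for (i, &x) in v.iter().enumerate() {
        let lt = x < pivot;
        scratch[num_lt] = x; // kept if x < pivot
        scratch[len - 1 - i + num_lt] = x; // kept otherwise
        num_lt += lt as usize;
    }
    (scratch, num_lt)
}
// branchless_partition_demo(&[0, 5, 1, 6, 2, 7, 3, 8], 5)
//     == (vec![0, 1, 2, 3, 8, 7, 6, 5], 4)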
- unsafe fn partition_fill_scratch( - arr_ptr: *mut Self, - len: usize, - scratch_ptr: *mut Self, - pivot_ptr: *const Self, + fn partition_fill_scratch( + v: &[T], + scratch: &mut [MaybeUninit], + pivot_pos: usize, is_less: &mut F, ) -> usize where F: FnMut(&Self, &Self) -> bool, { - // Partitioning loop manually unrolled to ensure good performance. Example T == u64, on x86 - // LLVM unrolls this loop but not on Arm. A compile time fixed size loop as based on - // `unroll_len` is reliably unrolled by all backends. And if `unroll_len` is `1` the inner - // loop can trivially be removed. - // - // The scheme used to unroll is somewhat weird, and focused on avoiding multi-instantiation - // of the inner loop part, which can have large effects on compile-time for non integer like - // types. - // - // Benchmarks show that for any Type of at most 16 bytes, double storing is more efficient - // than conditional store, especially on Firestorm (apple-m1). It is also less at risk of - // having the compiler generating a branch instead of conditional store. - - // SAFETY: The element access is arr_ptr + i, where i < len, which makes it proven - // in-bounds, assuming the caller upholds the function safety contract. The two output - // pointers `scratch_ptr` and `ge_out_ptr` each point to a unique location within the range - // of `scratch_ptr`, and the combination of always doing decrementing `ge_out_ptr` and - // conditionally incrementing `lt_count` ensures that every location of `scratch_ptr` will - // be written. If `is_less` panics, only un-observed copies were written into the scratch - // space. - unsafe { - const UNROLL_LEN: usize = 4; + let len = v.len(); + let arr_ptr = v.as_ptr(); + let scratch_ptr = MaybeUninit::slice_as_mut_ptr(scratch); - // lt == less than, ge == greater or equal - let mut lt_count = 0; - let mut ge_out_ptr = scratch_ptr.add(len); + if intrinsics::unlikely(scratch.len() < len || pivot_pos >= len) { + core::intrinsics::abort() + } + unsafe { + // SAFETY: exactly the same invariants and logic as the + // non-specialized impl. The conditional store being replaced by a + // double copy changes nothing, on all but the final iteration the + // bad write will simply be overwritten by a later iteration, and on + // the final iteration we write to the same index twice. And we do + // naive loop unrolling where the exact same loop body is just + // repeated. + let pivot = arr_ptr.add(pivot_pos); + let mut pivot_in_scratch = ptr::null_mut(); + let mut num_lt = 0; + let mut scratch_rev = scratch_ptr.add(len); macro_rules! loop_body { - ($elem_ptr:expr) => { - let elem_ptr = $elem_ptr; - ge_out_ptr = ge_out_ptr.sub(1); - - let is_less_than_pivot = is_less(&*elem_ptr, &*pivot_ptr); - - ptr::copy_nonoverlapping(elem_ptr, scratch_ptr.add(lt_count), 1); - ptr::copy_nonoverlapping(elem_ptr, ge_out_ptr.add(lt_count), 1); - - lt_count += is_less_than_pivot as usize; + ($i:expr) => { + let scan = arr_ptr.add($i); + scratch_rev = scratch_rev.sub(1); + + let is_less_than_pivot = is_less(&*scan, &*pivot); + ptr::copy_nonoverlapping(scan, scratch_ptr.add(num_lt), 1); + ptr::copy_nonoverlapping(scan, scratch_rev.add(num_lt), 1); + + // Save pivot location in scratch for later. 
+ if const { crate::has_direct_interior_mutability::() } + && intrinsics::unlikely(scan as *const T == pivot) + { + pivot_in_scratch = if is_less_than_pivot { + scratch_ptr.add(num_lt) // i + num_lt + } else { + scratch_rev.add(num_lt) // len - (i + 1) + num_lt = len - 1 - num_ge + }; + } + + num_lt += is_less_than_pivot as usize; }; } + /// To ensure good performance across platforms we explicitly unroll using a + /// fixed-size inner loop. We do not simply call the loop body multiple times as + /// this increases compile times significantly more, and the compiler unrolls + /// a fixed loop just as well, if it is sensible. + const UNROLL_LEN: usize = 4; let mut offset = 0; for _ in 0..(len / UNROLL_LEN) { for unroll_i in 0..UNROLL_LEN { - loop_body!(arr_ptr.add(offset + unroll_i)); + loop_body!(offset + unroll_i); } offset += UNROLL_LEN; } for i in 0..(len % UNROLL_LEN) { - loop_body!(arr_ptr.add(offset + i)); + loop_body!(offset + i); + } + + // SAFETY: if T has interior mutability our copy in scratch can be + // outdated, update it. + if const { crate::has_direct_interior_mutability::() } { + ptr::copy_nonoverlapping(pivot, pivot_in_scratch, 1); } - lt_count + num_lt } } } diff --git a/src/smallsort.rs b/src/smallsort.rs index 59a43c8..033accb 100644 --- a/src/smallsort.rs +++ b/src/smallsort.rs @@ -1,6 +1,13 @@ use core::mem::{self, ManuallyDrop, MaybeUninit}; use core::ptr; +// It's important to differentiate between small-sort performance for small slices and +// small-sort performance sorting small sub-slices as part of the main quicksort loop. For the +// former, testing showed that the representative benchmarks for real-world performance are cold +// CPU state and not single-size hot benchmarks. For the latter the CPU will call them many +// times, so hot benchmarks are fine and more realistic. And it's worth it to optimize sorting +// small sub-slices with more sophisticated solutions than insertion sort. + /// Sorts `v` using strategies optimized for small sizes. pub fn sort_small(v: &mut [T], is_less: &mut F) where
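// [Illustrative sketch, not part of the patch.] The unrolling pattern used in
// the partition loop above, in isolation: process fixed-size chunks with a
// constant-trip-count inner loop that backends reliably unroll, then handle
// the remainder separately.
fn unrolled_visit_demo(v: &[u64]) -> u64 {
    const UNROLL_LEN: usize = 4;
    let mut acc = 0u64;
    let mut offset = 0;
    for _ in 0..(v.len() / UNROLL_LEN) {
        for unroll_i in 0..UNROLL_LEN {
            acc = acc.wrapping_add(v[offset + unroll_i]);
        }
        offset += UNROLL_LEN;
    }
    for i in 0..(v.len() % UNROLL_LEN) {
        acc = acc.wrapping_add(v[offset + i]);
    }
    acc
}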