Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize compile time #27

Merged
merged 4 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/drift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,10 +253,7 @@ fn create_run<T, F: FnMut(&T, &T) -> bool>(
///
/// Returns the length of the run, and a bool that is false when the run
/// is ascending, and true if the run strictly descending.
fn find_existing_run<T, F>(v: &[T], is_less: &mut F) -> (usize, bool)
where
F: FnMut(&T, &T) -> bool,
{
fn find_existing_run<T, F: FnMut(&T, &T) -> bool>(v: &[T], is_less: &mut F) -> (usize, bool) {
let len = v.len();
if len < 2 {
return (len, false);
Expand Down
13 changes: 3 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use core::slice;

mod drift;
mod merge;
mod pivot;
mod quicksort;
mod smallsort;

Expand Down Expand Up @@ -65,11 +66,7 @@ fn stable_sort<T, F: FnMut(&T, &T) -> bool>(v: &mut [T], mut is_less: F) {
}

#[inline(always)]
fn driftsort<T, F, BufT>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
BufT: BufGuard<T>,
{
fn driftsort<T, F: FnMut(&T, &T) -> bool, BufT: BufGuard<T>>(v: &mut [T], is_less: &mut F) {
// Arrays of zero-sized types are always all-equal, and thus sorted.
if T::IS_ZST {
return;
Expand Down Expand Up @@ -100,11 +97,7 @@ where
// Deliberately don't inline the core logic to ensure the inlined insertion sort i-cache footprint
// is minimal.
#[inline(never)]
fn driftsort_main<T, F, BufT>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
BufT: BufGuard<T>,
{
fn driftsort_main<T, F: FnMut(&T, &T) -> bool, BufT: BufGuard<T>>(v: &mut [T], is_less: &mut F) {
// Pick whichever is greater:
// - alloc len elements up to MAX_FULL_ALLOC_BYTES
// - alloc len / 2 elements
Expand Down
10 changes: 6 additions & 4 deletions src/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ use core::ptr;

/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
/// stores the result into `v[..]`.
pub fn merge<T, F>(v: &mut [T], scratch: &mut [MaybeUninit<T>], mid: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
pub fn merge<T, F: FnMut(&T, &T) -> bool>(
v: &mut [T],
scratch: &mut [MaybeUninit<T>],
mid: usize,
is_less: &mut F,
) {
let len = v.len();

if mid == 0 || mid >= len || scratch.len() < cmp::min(mid, len - mid) {
Expand Down
84 changes: 84 additions & 0 deletions src/pivot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Recursively select a pseudomedian if above this threshold.
Voultapher marked this conversation as resolved.
Show resolved Hide resolved
const PSEUDO_MEDIAN_REC_THRESHOLD: usize = 64;

/// Selects a pivot from `v`. Algorithm taken from glidesort by Orson Peters.
///
/// This chooses a pivot by sampling an adaptive amount of points, approximating
/// the quality of a median of sqrt(n) elements.
pub fn choose_pivot<T, F: FnMut(&T, &T) -> bool>(v: &[T], is_less: &mut F) -> usize {
// We use unsafe code and raw pointers here because we're dealing with
// heavy recursion. Passing safe slices around would involve a lot of
// branches and function call overhead.

// SAFETY: a, b, c point to initialized regions of len_div_8 elements,
// satisfying median3 and median3_rec's preconditions as v_base points
// to an initialized region of n = len elements.
unsafe {
let v_base = v.as_ptr();
let len = v.len();
let len_div_8 = len / 8;

let a = v_base; // [0, floor(n/8))
let b = v_base.add(len_div_8 * 4); // [4*floor(n/8), 5*floor(n/8))
let c = v_base.add(len_div_8 * 7); // [7*floor(n/8), 8*floor(n/8))

if len < PSEUDO_MEDIAN_REC_THRESHOLD {
median3(&*a, &*b, &*c, is_less).sub_ptr(v_base)
} else {
median3_rec(a, b, c, len_div_8, is_less).sub_ptr(v_base)
}
}
}

/// Calculates an approximate median of 3 elements from sections a, b, c, or
/// recursively from an approximation of each, if they're large enough. By
/// dividing the size of each section by 8 when recursing we have logarithmic
/// recursion depth and overall sample from f(n) = 3*f(n/8) -> f(n) =
/// O(n^(log(3)/log(8))) ~= O(n^0.528) elements.
///
/// SAFETY: a, b, c must point to the start of initialized regions of memory of
/// at least n elements.
unsafe fn median3_rec<T, F: FnMut(&T, &T) -> bool>(
mut a: *const T,
mut b: *const T,
mut c: *const T,
n: usize,
is_less: &mut F,
) -> *const T {
// SAFETY: a, b, c still point to initialized regions of n / 8 elements,
// by the exact same logic as in choose_pivot.
unsafe {
if n * 8 >= PSEUDO_MEDIAN_REC_THRESHOLD {
let n8 = n / 8;
a = median3_rec(a, a.add(n8 * 4), a.add(n8 * 7), n8, is_less);
b = median3_rec(b, b.add(n8 * 4), b.add(n8 * 7), n8, is_less);
c = median3_rec(c, c.add(n8 * 4), c.add(n8 * 7), n8, is_less);
}
median3(&*a, &*b, &*c, is_less)
}
}

/// Calculates the median of 3 elements.
///
/// SAFETY: a, b, c must be valid initialized elements.
#[inline(always)]
fn median3<T, F: FnMut(&T, &T) -> bool>(a: &T, b: &T, c: &T, is_less: &mut F) -> *const T {
// Compiler tends to make this branchless when sensible, and avoids the
// third comparison when not.
let x = is_less(a, b);
let y = is_less(a, c);
if x == y {
// If x=y=0 then b, c <= a. In this case we want to return max(b, c).
// If x=y=1 then a < b, c. In this case we want to return min(b, c).
// By toggling the outcome of b < c using XOR x we get this behavior.
let z = is_less(b, c);
if z ^ x {
c
} else {
b
}
} else {
// Either c <= a < b or b <= a < c, thus a is our median.
a
}
}
128 changes: 11 additions & 117 deletions src/quicksort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,20 @@ use core::mem::{self, ManuallyDrop, MaybeUninit};
use core::ptr;

use crate::has_direct_interior_mutability;
use crate::pivot::choose_pivot;
use crate::smallsort::SmallSortTypeImpl;

// Recursively select a pseudomedian if above this threshold.
const PSEUDO_MEDIAN_REC_THRESHOLD: usize = 64;

/// Sorts `v` recursively using quicksort.
///
/// `limit` when initialized with `c*log(v.len())` for some c ensures we do not
/// overflow the stack or go quadratic.
pub fn stable_quicksort<T, F>(
pub fn stable_quicksort<T, F: FnMut(&T, &T) -> bool>(
mut v: &mut [T],
scratch: &mut [MaybeUninit<T>],
mut limit: u32,
mut left_ancestor_pivot: Option<&T>,
is_less: &mut F,
) where
F: FnMut(&T, &T) -> bool,
{
) {
loop {
let len = v.len();

Expand Down Expand Up @@ -79,112 +75,18 @@ pub fn stable_quicksort<T, F>(
}
}

/// Selects a pivot from `v`. Algorithm taken from glidesort by Orson Peters.
///
/// This chooses a pivot by sampling an adaptive amount of points, approximating
/// the quality of a median of sqrt(n) elements.
fn choose_pivot<T, F>(v: &[T], is_less: &mut F) -> usize
where
F: FnMut(&T, &T) -> bool,
{
// We use unsafe code and raw pointers here because we're dealing with
// heavy recursion. Passing safe slices around would involve a lot of
// branches and function call overhead.

// SAFETY: a, b, c point to initialized regions of len_div_8 elements,
// satisfying median3 and median3_rec's preconditions as v_base points
// to an initialized region of n = len elements.
unsafe {
let v_base = v.as_ptr();
let len = v.len();
let len_div_8 = len / 8;

let a = v_base; // [0, floor(n/8))
let b = v_base.add(len_div_8 * 4); // [4*floor(n/8), 5*floor(n/8))
let c = v_base.add(len_div_8 * 7); // [7*floor(n/8), 8*floor(n/8))

if len < PSEUDO_MEDIAN_REC_THRESHOLD {
median3(&*a, &*b, &*c, is_less).sub_ptr(v_base)
} else {
median3_rec(a, b, c, len_div_8, is_less).sub_ptr(v_base)
}
}
}

/// Calculates an approximate median of 3 elements from sections a, b, c, or
/// recursively from an approximation of each, if they're large enough. By
/// dividing the size of each section by 8 when recursing we have logarithmic
/// recursion depth and overall sample from f(n) = 3*f(n/8) -> f(n) =
/// O(n^(log(3)/log(8))) ~= O(n^0.528) elements.
///
/// SAFETY: a, b, c must point to the start of initialized regions of memory of
/// at least n elements.
unsafe fn median3_rec<T, F>(
mut a: *const T,
mut b: *const T,
mut c: *const T,
n: usize,
is_less: &mut F,
) -> *const T
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: a, b, c still point to initialized regions of n / 8 elements,
// by the exact same logic as in choose_pivot.
unsafe {
if n * 8 >= PSEUDO_MEDIAN_REC_THRESHOLD {
let n8 = n / 8;
a = median3_rec(a, a.add(n8 * 4), a.add(n8 * 7), n8, is_less);
b = median3_rec(b, b.add(n8 * 4), b.add(n8 * 7), n8, is_less);
c = median3_rec(c, c.add(n8 * 4), c.add(n8 * 7), n8, is_less);
}
median3(&*a, &*b, &*c, is_less)
}
}

/// Calculates the median of 3 elements.
///
/// SAFETY: a, b, c must be valid initialized elements.
#[inline(always)]
fn median3<T, F>(a: &T, b: &T, c: &T, is_less: &mut F) -> *const T
where
F: FnMut(&T, &T) -> bool,
{
// Compiler tends to make this branchless when sensible, and avoids the
// third comparison when not.
let x = is_less(a, b);
let y = is_less(a, c);
if x == y {
// If x=y=0 then b, c <= a. In this case we want to return max(b, c).
// If x=y=1 then a < b, c. In this case we want to return min(b, c).
// By toggling the outcome of b < c using XOR x we get this behavior.
let z = is_less(b, c);
if z ^ x {
c
} else {
b
}
} else {
// Either c <= a < b or b <= a < c, thus a is our median.
a
}
}

/// Partitions `v` using pivot `p = v[pivot_pos]` and returns the number of
/// elements less than `p`. The relative order of elements that compare < p and
/// those that compare >= p is preserved - it is a stable partition.
///
/// If `is_less` is not a strict total order or panics, `scratch.len() < v.len()`,
/// or `pivot_pos >= v.len()`, the result and `v`'s state is sound but unspecified.
fn stable_partition<T, F>(
fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
v: &mut [T],
scratch: &mut [MaybeUninit<T>],
pivot_pos: usize,
is_less: &mut F,
) -> usize
where
F: FnMut(&T, &T) -> bool,
{
) -> usize {
let num_lt = T::partition_fill_scratch(v, scratch, pivot_pos, is_less);

// SAFETY: partition_fill_scratch guarantees that scratch is initialized
Expand Down Expand Up @@ -213,27 +115,22 @@ trait StablePartitionTypeImpl: Sized {
/// Performs the same operation as [`stable_partition`], except it stores the
/// permuted elements as copies in `scratch`, with the >= partition in
/// reverse order.
fn partition_fill_scratch<F>(
fn partition_fill_scratch<F: FnMut(&Self, &Self) -> bool>(
v: &[Self],
scratch: &mut [MaybeUninit<Self>],
pivot_pos: usize,
is_less: &mut F,
) -> usize
where
F: FnMut(&Self, &Self) -> bool;
) -> usize;
}

impl<T> StablePartitionTypeImpl for T {
/// See [`StablePartitionTypeImpl::partition_fill_scratch`].
default fn partition_fill_scratch<F>(
default fn partition_fill_scratch<F: FnMut(&Self, &Self) -> bool>(
v: &[T],
scratch: &mut [MaybeUninit<T>],
pivot_pos: usize,
is_less: &mut F,
) -> usize
where
F: FnMut(&Self, &Self) -> bool,
{
) -> usize {
let len = v.len();
let v_base = v.as_ptr();
let scratch_base = MaybeUninit::slice_as_mut_ptr(scratch);
Expand Down Expand Up @@ -309,15 +206,12 @@ where
(): crate::IsTrue<{ mem::size_of::<T>() <= (mem::size_of::<u64>() * 2) }>,
{
/// See [`StablePartitionTypeImpl::partition_fill_scratch`].
fn partition_fill_scratch<F>(
fn partition_fill_scratch<F: FnMut(&Self, &Self) -> bool>(
v: &[T],
scratch: &mut [MaybeUninit<T>],
pivot_pos: usize,
is_less: &mut F,
) -> usize
where
F: FnMut(&Self, &Self) -> bool,
{
) -> usize {
let len = v.len();
let v_base = v.as_ptr();
let scratch_base = MaybeUninit::slice_as_mut_ptr(scratch);
Expand Down
Loading
Loading