Skip to content

Commit

Permalink
Apply review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Voultapher committed Mar 3, 2024
1 parent aa0c0f5 commit e8ea138
Showing 1 changed file with 77 additions and 56 deletions.
133 changes: 77 additions & 56 deletions src/quicksort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,6 @@ pub fn stable_quicksort<T, F: FnMut(&T, &T) -> bool>(
}
}

// Mutable cursor state threaded through the partition loop body: one read cursor into the
// input, a running less-than count, and one write cursor that walks the scratch buffer
// backwards.
struct PartitionState<T> {
// The current element that is being looked at, scans left to right through slice.
// Advanced by one element at the end of every loop body invocation.
scan: *const T,
// Counts the number of elements that compared less-than, also works around:
// https://github.com/rust-lang/rust/issues/117128
// It is also the offset added to the chosen destination base pointer on every write.
num_lt: usize,
// Reverse scratch output pointer.
// Decremented by one element before each write; used as the destination base for elements
// that did not compare less-than.
scratch_rev: *mut T,
}

/// Partitions `v` using pivot `p = v[pivot_pos]` and returns the number of
/// elements less than `p`. The relative order of elements that compare < p and
/// those that compare >= p is preserved - it is a stable partition.
Expand All @@ -109,7 +99,10 @@ fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
let v_base = v.as_ptr();
let scratch_base = MaybeUninit::slice_as_mut_ptr(scratch);

// TODO explain logic on high level.
// The core idea is to write the values that compare as less-than to the left side of `scratch`,
// while the values that compare as greater than or equal to `v[pivot_pos]` go to the right side
// of `scratch` in reverse. Most of the inner complexity stems from avoiding self-comparisons
// with the pivot and from delaying the pivot hole filling because of non-`Freeze` types.

// Regarding auto unrolling and manual unrolling. Auto unrolling as tested with rustc 1.75 is
// somewhat run-time and binary-size inefficient, because it performs additional math to
Expand All @@ -121,38 +114,32 @@ fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
// component of the sort implementation.

// SAFETY: we checked that pivot_pos is in-bounds above, and that scratch has length at least
// len. As we do binary classification into lt or ge, the invariant num_lt + num_ge = i always
// holds at the start of each iteration. For micro-optimization reasons we write i - num_lt
// instead of num_gt. Since num_lt increases by at most 1 each iteration and since i < len, this
// proves our destination indices num_lt and len - 1 - num_ge stay in-bounds, and are never
// equal except the final iteration when num_lt = len - 1 - (len - 1 - num_lt) = len - 1 -
// num_ge. We write one different element to scratch each iteration thus scratch[0..len] will be
// initialized with a permutation of v. TODO extend with nested loop logic.
// len. As we do binary classification into lt or ge, the invariant num_left + num_right = i
// always holds at the start of each iteration. For micro-optimization reasons we write i -
// num_left instead of num_right. Since num_left increases by at most 1 each iteration and since
// i < len, this proves our destination indices num_left and len - 1 - num_right stay in-bounds,
// and are never equal except the final iteration when num_left = len - 1 - (len - 1 - num_left)
// = len - 1 - num_right. We write one element to scratch each iteration thus scratch[0..len]
// will be initialized with a permutation of v. The body of `loop` has nearly the same semantics
// as:
// ```
// for 0..len {
// state.partition_one(is_less(&*state.scan, &*pivot));
// }
// ```
// Where we treat `state.scan == pivot` specially to avoid calling is_less with the same value.
// Self-comparison is not directly UB or problematic in and of itself, but it's possible that
// user logic depends on this not occurring, e.g. where the comparison function takes a lock,
// which would deadlock.
unsafe {
// SAFETY: exactly the same invariants and logic as the non-specialized impl. And we do
// naive loop unrolling where the exact same loop body is just repeated.
let pivot = v_base.add(pivot_pos);

let mut loop_body = |state: &mut PartitionState<T>| {
// println!("state.scan: {}", state.scan.sub_ptr(v_base));
state.scratch_rev = state.scratch_rev.sub(1);

// println!("loop_body state.scan: {:?}", *(state.scan as *const DebugT));
let is_less_than_pivot = is_less(&*state.scan, &*pivot);
let dst_base = if is_less_than_pivot {
scratch_base // i + num_lt
} else {
state.scratch_rev // len - (i + 1) + num_lt = len - 1 - num_ge
};
ptr::copy_nonoverlapping(state.scan, dst_base.add(state.num_lt), 1);

state.num_lt += is_less_than_pivot as usize;
state.scan = state.scan.add(1);
};

let mut state = PartitionState {
scratch_base,
scan: v_base,
num_lt: 0,
num_left: 0,
scratch_rev: scratch_base.add(len),
};

Expand All @@ -165,16 +152,16 @@ fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
const UNROLL_LEN: usize = 4;
let unroll_end = v_base.add(loop_end_pos.saturating_sub(UNROLL_LEN - 1));
while state.scan < unroll_end {
loop_body(&mut state);
loop_body(&mut state);
loop_body(&mut state);
loop_body(&mut state);
state.partition_one(is_less(&*state.scan, &*pivot));
state.partition_one(is_less(&*state.scan, &*pivot));
state.partition_one(is_less(&*state.scan, &*pivot));
state.partition_one(is_less(&*state.scan, &*pivot));
}
}

let loop_end = v_base.add(loop_end_pos);
while state.scan < loop_end {
loop_body(&mut state);
state.partition_one(is_less(&*state.scan, &*pivot));
}

if loop_end_pos == len {
Expand All @@ -183,15 +170,7 @@ fn stable_partition<T, F: FnMut(&T, &T) -> bool>(

// Handle pivot, doing it this way neatly handles types with interior mutability and
// avoids self-comparison as well as a branch in the inner partition loop.
state.scratch_rev = state.scratch_rev.sub(1);
let pivot_dst_base = if pivot_goes_left {
scratch_base
} else {
state.scratch_rev
};
pivot_in_scratch = pivot_dst_base.add(state.num_lt);
state.num_lt += pivot_goes_left as usize;
state.scan = state.scan.add(1);
pivot_in_scratch = state.partition_one(pivot_goes_left);

loop_end_pos = len;
}
Expand All @@ -200,22 +179,64 @@ fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
ptr::copy_nonoverlapping(pivot, pivot_in_scratch, 1);

// SAFETY: partition_fill_scratch guarantees that scratch is initialized with a permuted
// copy of `v`, and that num_lt <= v.len(). Copying scratch[0..num_lt] and
// scratch[num_lt..v.len()] back is thus sound, as the values in scratch will never be read
// copy of `v`, and that num_left <= v.len(). Copying scratch[0..num_left] and
// scratch[num_left..v.len()] back is thus sound, as the values in scratch will never be read
// again, meaning our copies semantically act as moves, permuting `v`. Copy all the elements
// < p directly from swap to v.
let v_base = v.as_mut_ptr();
ptr::copy_nonoverlapping(scratch_base, v_base, state.num_lt);
ptr::copy_nonoverlapping(scratch_base, v_base, state.num_left);

// Copy the elements >= p in reverse order.
for i in 0..len - state.num_lt {
for i in 0..len - state.num_left {
ptr::copy_nonoverlapping(
scratch_base.add(len - 1 - i),
v_base.add(state.num_lt + i),
v_base.add(state.num_left + i),
1,
);
}

state.num_lt
state.num_left
}
}

/// Cursor state for one stable-partition pass that copies elements from an input range into a
/// scratch buffer, growing a "left" region from the front and a "right" region from the back.
struct PartitionState<T> {
    /// First element of the scratch buffer; left-side writes land at `scratch_base + num_left`.
    scratch_base: *mut T,
    /// Read cursor: the element that the next `partition_one` call copies. Moves left to right.
    scan: *const T,
    /// How many elements have been sent to the left side so far. Also works around
    /// <https://github.com/rust-lang/rust/issues/117128>.
    num_left: usize,
    /// Reverse write cursor into the scratch buffer; stepped back by one element per call and
    /// used as the base for right-side writes.
    scratch_rev: *mut T,
}

impl<T> PartitionState<T> {
    /// Copies the element under `scan` to the growing left side of the scratch memory when
    /// `towards_left` is `true`, otherwise to the growing right side (which fills in reverse),
    /// then advances the bookkeeping. Returns the address the element was written to.
    ///
    /// This is the branchless core of the partition.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `PartitionState` is initialized correctly:
    ///
    /// * `scratch_base` points to a contiguous region of `len` elements that is valid for
    ///   writes.
    /// * `scan` initially points to a contiguous region of `len` values that are valid for
    ///   reads.
    ///
    /// In addition this function MUST be called exactly `len` times; otherwise the values
    /// written to the `scratch_base` region must be considered incomplete and not be read
    /// again.
    unsafe fn partition_one(&mut self, towards_left: bool) -> *mut T {
        // SAFETY: Upheld by the caller, see the `# Safety` section above.
        unsafe {
            // Step the reverse cursor first so that it addresses this iteration's slot.
            let rev = self.scratch_rev.sub(1);
            self.scratch_rev = rev;

            // Left:  scratch_base + num_left
            // Right: scratch_rev + num_left = len - (i + 1) + num_left = len - 1 - num_right
            let target = (if towards_left { self.scratch_base } else { rev }).add(self.num_left);
            ptr::copy_nonoverlapping(self.scan, target, 1);

            self.scan = self.scan.add(1);
            self.num_left += usize::from(towards_left);
            target
        }
    }
}

0 comments on commit e8ea138

Please sign in to comment.