diff --git a/src/quicksort.rs b/src/quicksort.rs
index 2b25414..bee8435 100644
--- a/src/quicksort.rs
+++ b/src/quicksort.rs
@@ -77,16 +77,6 @@ pub fn stable_quicksort<T, F: FnMut(&T, &T) -> bool>(
     }
 }
 
-struct PartitionState<T> {
-    // The current element that is being looked at, scans left to right through slice.
-    scan: *const T,
-    // Counts the number of elements that compared less-than, also works around:
-    // https://github.com/rust-lang/rust/issues/117128
-    num_lt: usize,
-    // Reverse scratch output pointer.
-    scratch_rev: *mut T,
-}
-
 /// Partitions `v` using pivot `p = v[pivot_pos]` and returns the number of
 /// elements less than `p`. The relative order of elements that compare < p and
 /// those that compare >= p is preserved - it is a stable partition.
@@ -109,7 +99,10 @@ fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
     let v_base = v.as_ptr();
     let scratch_base = MaybeUninit::slice_as_mut_ptr(scratch);
 
-    // TODO explain logic on high level.
+    // The core idea is to write values that compare as less-than to the left side of `scratch`,
+    // while values that compare as greater than or equal to `v[pivot_pos]` go to the right side
+    // of `scratch`, in reverse. Most of the inner complexity stems from avoiding self-comparison
+    // with the pivot and from the delayed pivot hole filling needed for non-`Freeze` types.
 
     // Regarding auto unrolling and manual unrolling. Auto unrolling as tested with rustc 1.75 is
     // somewhat run-time and binary-size inefficient, because it performs additional math to
@@ -121,38 +114,32 @@ fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
     // component of the sort implementation.
 
     // SAFETY: we checked that pivot_pos is in-bounds above, and that scratch has length at least
-    // len. As we do binary classification into lt or ge, the invariant num_lt + num_ge = i always
-    // holds at the start of each iteration. For micro-optimization reasons we write i - num_lt
-    // instead of num_gt. Since num_lt increases by at most 1 each iteration and since i < len, this
-    // proves our destination indices num_lt and len - 1 - num_ge stay in-bounds, and are never
-    // equal except the final iteration when num_lt = len - 1 - (len - 1 - num_lt) = len - 1 -
-    // num_ge. We write one different element to scratch each iteration thus scratch[0..len] will be
-    // initialized with a permutation of v. TODO extend with nested loop logic.
+    // len. As we do binary classification into lt or ge, the invariant num_left + num_right = i
+    // always holds at the start of each iteration. For micro-optimization reasons we write
+    // i - num_left instead of num_right. Since num_left increases by at most 1 each iteration and
+    // since i < len, this proves our destination indices num_left and len - 1 - num_right stay
+    // in-bounds, and are never equal except in the final iteration, when num_left =
+    // len - 1 - (len - 1 - num_left) = len - 1 - num_right. We write one element to scratch each
+    // iteration, thus scratch[0..len] will be initialized with a permutation of v. The loops below
+    // have nearly the same semantics as:
+    // ```
+    // for _ in 0..len {
+    //     state.partition_one(is_less(&*state.scan, &*pivot));
+    // }
+    // ```
+    // We treat `state.scan == pivot` specially to avoid calling is_less with the same value on
+    // both sides. Self-comparison is not directly UB or problematic in and of itself, but it's
+    // possible that user logic depends on it not occurring, e.g. a comparison function that takes
+    // a lock, which would then deadlock.
     unsafe {
         // SAFETY: exactly the same invariants and logic as the non-specialized impl. And we do
         // naive loop unrolling where the exact same loop body is just repeated.
         let pivot = v_base.add(pivot_pos);
 
-        let mut loop_body = |state: &mut PartitionState<T>| {
-            // println!("state.scan: {}", state.scan.sub_ptr(v_base));
-            state.scratch_rev = state.scratch_rev.sub(1);
-
-            // println!("loop_body state.scan: {:?}", *(state.scan as *const DebugT));
-            let is_less_than_pivot = is_less(&*state.scan, &*pivot);
-            let dst_base = if is_less_than_pivot {
-                scratch_base // i + num_lt
-            } else {
-                state.scratch_rev // len - (i + 1) + num_lt = len - 1 - num_ge
-            };
-            ptr::copy_nonoverlapping(state.scan, dst_base.add(state.num_lt), 1);
-
-            state.num_lt += is_less_than_pivot as usize;
-            state.scan = state.scan.add(1);
-        };
-
         let mut state = PartitionState {
+            scratch_base,
             scan: v_base,
-            num_lt: 0,
+            num_left: 0,
             scratch_rev: scratch_base.add(len),
         };
 
@@ -165,16 +152,16 @@ fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
             const UNROLL_LEN: usize = 4;
             let unroll_end = v_base.add(loop_end_pos.saturating_sub(UNROLL_LEN - 1));
             while state.scan < unroll_end {
-                loop_body(&mut state);
-                loop_body(&mut state);
-                loop_body(&mut state);
-                loop_body(&mut state);
+                state.partition_one(is_less(&*state.scan, &*pivot));
+                state.partition_one(is_less(&*state.scan, &*pivot));
+                state.partition_one(is_less(&*state.scan, &*pivot));
+                state.partition_one(is_less(&*state.scan, &*pivot));
             }
         }
 
         let loop_end = v_base.add(loop_end_pos);
         while state.scan < loop_end {
-            loop_body(&mut state);
+            state.partition_one(is_less(&*state.scan, &*pivot));
         }
 
         if loop_end_pos == len {
@@ -183,15 +170,7 @@ fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
 
             // Handle pivot, doing it this way neatly handles type with interior mutability and
             // avoids self comparison as well as a branch in the inner partition loop.
-            state.scratch_rev = state.scratch_rev.sub(1);
-            let pivot_dst_base = if pivot_goes_left {
-                scratch_base
-            } else {
-                state.scratch_rev
-            };
-            pivot_in_scratch = pivot_dst_base.add(state.num_lt);
-            state.num_lt += pivot_goes_left as usize;
-            state.scan = state.scan.add(1);
+            pivot_in_scratch = state.partition_one(pivot_goes_left);
 
             loop_end_pos = len;
         }
@@ -200,22 +179,64 @@ fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
         ptr::copy_nonoverlapping(pivot, pivot_in_scratch, 1);
 
         // SAFETY: partition_fill_scratch guarantees that scratch is initialized with a permuted
-        // copy of `v`, and that num_lt <= v.len(). Copying scratch[0..num_lt] and
-        // scratch[num_lt..v.len()] back is thus sound, as the values in scratch will never be read
+        // copy of `v`, and that num_left <= v.len(). Copying scratch[0..num_left] and
+        // scratch[num_left..v.len()] back is thus sound, as the values in scratch will never be read
         // again, meaning our copies semantically act as moves, permuting `v`. Copy all the elements
        // < p directly from swap to v.
         let v_base = v.as_mut_ptr();
-        ptr::copy_nonoverlapping(scratch_base, v_base, state.num_lt);
+        ptr::copy_nonoverlapping(scratch_base, v_base, state.num_left);
 
         // Copy the elements >= p in reverse order.
-        for i in 0..len - state.num_lt {
+        for i in 0..len - state.num_left {
             ptr::copy_nonoverlapping(
                 scratch_base.add(len - 1 - i),
-                v_base.add(state.num_lt + i),
+                v_base.add(state.num_left + i),
                 1,
             );
         }
 
-        state.num_lt
+        state.num_left
+    }
+}
+
+struct PartitionState<T> {
+    // The start of the scratch auxiliary memory.
+    scratch_base: *mut T,
+    // The current element that is being looked at; scans left to right through the slice.
+    scan: *const T,
+    // Counts the number of elements that went to the left side; also works around:
+    // https://github.com/rust-lang/rust/issues/117128
+    num_left: usize,
+    // Reverse scratch output pointer.
+    scratch_rev: *mut T,
+}
+
+impl<T> PartitionState<T> {
+    /// Depending on the value of `towards_left`, writes the value at `scan` to the growing left
+    /// or right side of the scratch memory and updates the state accordingly. This forms the
+    /// branchless core of the partition.
+    ///
+    /// SAFETY: The caller must ensure that `PartitionState` is initialized correctly, where
+    /// `scratch_base` points to a contiguous region of `len` elements that is valid for writing.
+    /// `scan` must initially point to a contiguous region of `len` values that are valid to be
+    /// read. In addition this function MUST be called exactly `len` times, otherwise the values
+    /// written to the `scratch_base` region must be considered incomplete and not be read again.
+    unsafe fn partition_one(&mut self, towards_left: bool) -> *mut T {
+        // SAFETY: See function safety comment.
+        unsafe {
+            self.scratch_rev = self.scratch_rev.sub(1);
+
+            let dst_base = if towards_left {
+                self.scratch_base // i + num_left
+            } else {
+                self.scratch_rev // len - (i + 1) + num_left = len - 1 - num_right
+            };
+            let dst = dst_base.add(self.num_left);
+            ptr::copy_nonoverlapping(self.scan, dst, 1);
+
+            self.num_left += towards_left as usize;
+            self.scan = self.scan.add(1);
+            dst
+        }
     }
 }
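Note (illustrative only, not part of the diff): the partition scheme described in the new comment block can be summarized with a small, safe Rust sketch. The name `stable_partition_sketch` and its signature are hypothetical; it clones values into a `Vec` instead of writing through raw pointers into `MaybeUninit` scratch, and it fills the pivot slot immediately rather than delaying it, so it only mirrors the index logic (`num_left`, reverse right-side writes), not the actual implementation above.

```rust
/// Simplified model of the scheme: `< pivot` values grow from the front of a scratch buffer,
/// `>= pivot` values are written from the back in reverse, and the back half is reversed again
/// on copy-back, which keeps the partition stable.
fn stable_partition_sketch<T: Clone, F: FnMut(&T, &T) -> bool>(
    v: &mut [T],
    pivot_pos: usize,
    pivot_goes_left: bool,
    is_less: &mut F,
) -> usize {
    let len = v.len();
    assert!(pivot_pos < len);
    let pivot = v[pivot_pos].clone();
    let mut scratch: Vec<Option<T>> = vec![None; len];

    let mut num_left = 0; // elements written to the left side so far
    let mut scratch_rev = len; // reverse write cursor for the right side

    for i in 0..len {
        // The pivot is classified without calling `is_less` on itself, mirroring the
        // special `pivot_goes_left` handling in the real code.
        let towards_left = if i == pivot_pos {
            pivot_goes_left
        } else {
            is_less(&v[i], &pivot)
        };

        scratch_rev -= 1;
        if towards_left {
            scratch[num_left] = Some(v[i].clone());
            num_left += 1;
        } else {
            // Right-side destination is len - 1 - num_right, i.e. filled in reverse.
            scratch[scratch_rev + num_left] = Some(v[i].clone());
        }
    }

    // Copy the left side back as-is, then the right side back in reverse, restoring the
    // original relative order of the >= pivot elements.
    for i in 0..num_left {
        v[i] = scratch[i].take().unwrap();
    }
    for i in 0..len - num_left {
        v[num_left + i] = scratch[len - 1 - i].take().unwrap();
    }
    num_left
}
```

For example, `stable_partition_sketch(&mut [3, 1, 4, 1, 5], 2, false, &mut |a, b| a < b)` returns 3 and leaves the slice as `[3, 1, 1, 4, 5]`, with both groups keeping their original relative order.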