From 9c2d26b4b3811764185a986d666c9e5b8723a232 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Wed, 21 Feb 2024 00:23:29 +0100 Subject: [PATCH] Generalize bidirectional merge in smallsort to non-even lengths --- src/smallsort.rs | 181 ++++++++++++++++++++++++----------------------- 1 file changed, 94 insertions(+), 87 deletions(-) diff --git a/src/smallsort.rs b/src/smallsort.rs index 7f93014..97965ea 100644 --- a/src/smallsort.rs +++ b/src/smallsort.rs @@ -63,83 +63,78 @@ fn sort_small_general bool>( is_less: &mut F, ) { let len = v.len(); + if len < 2 { + return; + } - if len >= 2 { - if scratch.len() < MIN_SMALL_SORT_SCRATCH_LEN { - intrinsics::abort(); - } + if scratch.len() < v.len() + 16 { + intrinsics::abort(); + } - let v_base = v.as_mut_ptr(); + if len < 8 { + insertion_sort_shift_left(v, 1, is_less); + return; + } - let offset = if len >= 8 { - let len_div_2 = len / 2; - - // SAFETY: TODO - unsafe { - let scratch_base = scratch.as_mut_ptr() as *mut T; - - let presorted_len = if len >= 16 { - // SAFETY: scratch_base is valid and has enough space. - sort8_stable( - v_base, - scratch_base.add(T::SMALL_SORT_THRESHOLD), - scratch_base, - is_less, - ); - - sort8_stable( - v_base.add(len_div_2), - scratch_base.add(T::SMALL_SORT_THRESHOLD + 8), - scratch_base.add(len_div_2), - is_less, - ); - - 8 - } else { - // SAFETY: scratch_base is valid and has enough space. - sort4_stable(v_base, scratch_base, is_less); - sort4_stable(v_base.add(len_div_2), scratch_base.add(len_div_2), is_less); - - 4 - }; - - for offset in [0, len_div_2] { - let src = scratch_base.add(offset); - let dst = v_base.add(offset); - - for i in presorted_len..len_div_2 { - ptr::copy_nonoverlapping(dst.add(i), src.add(i), 1); - insert_tail(src, src.add(i), is_less); - } - } - - let even_len = len - (len % 2); - - // SAFETY: scratch_base is initialized with even_len elements, - // and v_base is large enough to copy to. - let drop_guard = CopyOnDrop { - src: scratch_base, - dst: v_base, - len: even_len, - }; - - // It's faster to merge directly into `v` and copy over the 'safe' elements of - // `scratch` into v only if there was a panic. This technique is similar to - // ping-pong merging. - bi_directional_merge_even( - &*ptr::slice_from_raw_parts(drop_guard.src, drop_guard.len), - drop_guard.dst, - is_less, - ); - mem::forget(drop_guard); - - even_len - } + let v_base = v.as_mut_ptr(); + let len_div_2 = len / 2; + + unsafe { + let scratch_base = scratch.as_mut_ptr() as *mut T; + + let presorted_len = if len >= 16 { + // SAFETY: scratch_base is valid and has enough space. + sort8_stable(v_base, scratch_base, scratch_base.add(len), is_less); + + sort8_stable( + v_base.add(len_div_2), + scratch_base.add(len_div_2), + scratch_base.add(len + 8), + is_less, + ); + + 8 } else { - 1 + // SAFETY: scratch_base is valid and has enough space. + sort4_stable(v_base, scratch_base, is_less); + sort4_stable(v_base.add(len_div_2), scratch_base.add(len_div_2), is_less); + + 4 }; - insertion_sort_shift_left(v, offset, is_less); + for offset in [0, len_div_2] { + // SAFETY: at this point dst is initialized with presorted_len elements. + // We extend this to desired_len, src is valid for desired_len elements. + let src = v_base.add(offset); + let dst = scratch_base.add(offset); + let desired_len = if offset == 0 { + len_div_2 + } else { + len - len_div_2 + }; + + for i in presorted_len..desired_len { + ptr::copy_nonoverlapping(src.add(i), dst.add(i), 1); + insert_tail(dst, dst.add(i), is_less); + } + } + + // See comment in `CopyOnDrop::drop`. + let drop_guard = CopyOnDrop { + src: scratch_base, + dst: v_base, + len, + }; + + // It's faster to merge directly into `v` and copy over the 'safe' elements of + // `scratch` into v only if there was a panic. This technique is similar to + // ping-pong merging. + bidirectional_merge( + &*ptr::slice_from_raw_parts(drop_guard.src, drop_guard.len), + drop_guard.dst, + is_less, + ); + mem::forget(drop_guard); } } @@ -297,8 +292,8 @@ pub unsafe fn sort4_stable bool>( #[inline(never)] unsafe fn sort8_stable bool>( v_base: *mut T, - scratch_base: *mut T, dst: *mut T, + scratch_base: *mut T, is_less: &mut F, ) { // SAFETY: The caller must guarantee that scratch_base is valid for 8 writes, and that v_base is @@ -310,7 +305,7 @@ unsafe fn sort8_stable bool>( // SAFETY: TODO unsafe { - bi_directional_merge_even(&*ptr::slice_from_raw_parts(scratch_base, 8), dst, is_less); + bidirectional_merge(&*ptr::slice_from_raw_parts(scratch_base, 8), dst, is_less); } } @@ -326,10 +321,10 @@ unsafe fn merge_up bool>( // // if !is_less(&*right_src, &*left_src) { // ptr::copy_nonoverlapping(left_src, dst, 1); - // left_src = left_src.wrapping_add(1); + // left_src = left_src.add(1); // } else { // ptr::copy_nonoverlapping(right_src, dst, 1); - // right_src = right_src.wrapping_add(1); + // right_src = right_src.add(1); // } // dst = dst.add(1); @@ -339,8 +334,8 @@ unsafe fn merge_up bool>( let is_l = !is_less(&*right_src, &*left_src); let src = if is_l { left_src } else { right_src }; ptr::copy_nonoverlapping(src, dst, 1); - right_src = right_src.wrapping_add(!is_l as usize); - left_src = left_src.wrapping_add(is_l as usize); + right_src = right_src.add(!is_l as usize); + left_src = left_src.add(is_l as usize); dst = dst.add(1); } @@ -380,15 +375,15 @@ unsafe fn merge_down bool>( (left_src, right_src, dst) } -/// Merge v assuming the len is even and v[..len / 2] and v[len / 2..] are sorted. +/// Merge v assuming v[..len / 2] and v[len / 2..] are sorted. /// /// Original idea for bi-directional merging by Igor van den Hoven (quadsort), adapted to only use /// merge up and down. In contrast to the original parity_merge function, it performs 2 writes /// instead of 4 per iteration. Ord violation detection was added. /// // SAFETY: the caller must guarantee that `dst` is valid for v.len() writes. -// Also `v.as_ptr` and `dst` must not alias. -unsafe fn bi_directional_merge_even bool>( +// Also `v.as_ptr` and `dst` must not alias. v.len() must be >= 2. +unsafe fn bidirectional_merge bool>( v: &[T], dst: *mut T, is_less: &mut F, @@ -435,22 +430,34 @@ unsafe fn bi_directional_merge_even bool>( // bounds if the caller guarantees that `dst` is valid for `v.len()` writes. unsafe { let mut left = src; - let mut right = src.wrapping_add(len_div_2); + let mut right = src.add(len_div_2); let mut dst = dst; - let mut left_rev = src.wrapping_add(len_div_2 - 1); - let mut right_rev = src.wrapping_add(len - 1); - let mut dst_rev = dst.wrapping_add(len - 1); + let mut left_rev = src.add(len_div_2 - 1); + let mut right_rev = src.add(len - 1); + let mut dst_rev = dst.add(len - 1); for _ in 0..len_div_2 { (left, right, dst) = merge_up(left, right, dst, is_less); (left_rev, right_rev, dst_rev) = merge_down(left_rev, right_rev, dst_rev, is_less); } - let left_diff = (left as usize).wrapping_sub(left_rev as usize); - let right_diff = (right as usize).wrapping_sub(right_rev as usize); + let left_end = left_rev.wrapping_add(1); + let right_end = right_rev.wrapping_add(1); + + // Odd length, so one element is left unconsumed in the input. + if len % 2 != 0 { + let left_nonempty = left < left_end; + let last_src = if left_nonempty { left } else { right }; + ptr::copy_nonoverlapping(last_src, dst, 1); + left = left.add(left_nonempty as usize); + right = right.add((!left_nonempty) as usize); + } - if !(left_diff == mem::size_of::() && right_diff == mem::size_of::()) { + // We now should have consumed the full input exactly once. This can + // only fail if the comparison operator fails to be Ord, in which case + // we will panic and never access the inconsistent state in dst. + if left != left_end || right != right_end { panic_on_ord_violation(); } }