Skip to content

Commit

Permalink
Generalize bidirectional merge in smallsort to non-even lengths
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored and Voultapher committed Mar 4, 2024
1 parent 079ce1a commit 9c2d26b
Showing 1 changed file with 94 additions and 87 deletions.
181 changes: 94 additions & 87 deletions src/smallsort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,83 +63,78 @@ fn sort_small_general<T: crate::Freeze, F: FnMut(&T, &T) -> bool>(
is_less: &mut F,
) {
let len = v.len();
if len < 2 {
return;
}

if len >= 2 {
if scratch.len() < MIN_SMALL_SORT_SCRATCH_LEN {
intrinsics::abort();
}
if scratch.len() < v.len() + 16 {
intrinsics::abort();
}

let v_base = v.as_mut_ptr();
if len < 8 {
insertion_sort_shift_left(v, 1, is_less);
return;
}

let offset = if len >= 8 {
let len_div_2 = len / 2;

// SAFETY: TODO
unsafe {
let scratch_base = scratch.as_mut_ptr() as *mut T;

let presorted_len = if len >= 16 {
// SAFETY: scratch_base is valid and has enough space.
sort8_stable(
v_base,
scratch_base.add(T::SMALL_SORT_THRESHOLD),
scratch_base,
is_less,
);

sort8_stable(
v_base.add(len_div_2),
scratch_base.add(T::SMALL_SORT_THRESHOLD + 8),
scratch_base.add(len_div_2),
is_less,
);

8
} else {
// SAFETY: scratch_base is valid and has enough space.
sort4_stable(v_base, scratch_base, is_less);
sort4_stable(v_base.add(len_div_2), scratch_base.add(len_div_2), is_less);

4
};

for offset in [0, len_div_2] {
let src = scratch_base.add(offset);
let dst = v_base.add(offset);

for i in presorted_len..len_div_2 {
ptr::copy_nonoverlapping(dst.add(i), src.add(i), 1);
insert_tail(src, src.add(i), is_less);
}
}

let even_len = len - (len % 2);

// SAFETY: scratch_base is initialized with even_len elements,
// and v_base is large enough to copy to.
let drop_guard = CopyOnDrop {
src: scratch_base,
dst: v_base,
len: even_len,
};

// It's faster to merge directly into `v` and copy over the 'safe' elements of
// `scratch` into v only if there was a panic. This technique is similar to
// ping-pong merging.
bi_directional_merge_even(
&*ptr::slice_from_raw_parts(drop_guard.src, drop_guard.len),
drop_guard.dst,
is_less,
);
mem::forget(drop_guard);

even_len
}
let v_base = v.as_mut_ptr();
let len_div_2 = len / 2;

unsafe {
let scratch_base = scratch.as_mut_ptr() as *mut T;

let presorted_len = if len >= 16 {
// SAFETY: scratch_base is valid and has enough space.
sort8_stable(v_base, scratch_base, scratch_base.add(len), is_less);

sort8_stable(
v_base.add(len_div_2),
scratch_base.add(len_div_2),
scratch_base.add(len + 8),
is_less,
);

8
} else {
1
// SAFETY: scratch_base is valid and has enough space.
sort4_stable(v_base, scratch_base, is_less);
sort4_stable(v_base.add(len_div_2), scratch_base.add(len_div_2), is_less);

4
};

insertion_sort_shift_left(v, offset, is_less);
for offset in [0, len_div_2] {
// SAFETY: at this point dst is initialized with presorted_len elements.
// We extend this to desired_len, src is valid for desired_len elements.
let src = v_base.add(offset);
let dst = scratch_base.add(offset);
let desired_len = if offset == 0 {
len_div_2
} else {
len - len_div_2
};

for i in presorted_len..desired_len {
ptr::copy_nonoverlapping(src.add(i), dst.add(i), 1);
insert_tail(dst, dst.add(i), is_less);
}
}

// See comment in `CopyOnDrop::drop`.
let drop_guard = CopyOnDrop {
src: scratch_base,
dst: v_base,
len,
};

// It's faster to merge directly into `v` and copy over the 'safe' elements of
// `scratch` into v only if there was a panic. This technique is similar to
// ping-pong merging.
bidirectional_merge(
&*ptr::slice_from_raw_parts(drop_guard.src, drop_guard.len),
drop_guard.dst,
is_less,
);
mem::forget(drop_guard);
}
}

Expand Down Expand Up @@ -297,8 +292,8 @@ pub unsafe fn sort4_stable<T, F: FnMut(&T, &T) -> bool>(
#[inline(never)]
unsafe fn sort8_stable<T: crate::Freeze, F: FnMut(&T, &T) -> bool>(
v_base: *mut T,
scratch_base: *mut T,
dst: *mut T,
scratch_base: *mut T,
is_less: &mut F,
) {
// SAFETY: The caller must guarantee that scratch_base is valid for 8 writes, and that v_base is
Expand All @@ -310,7 +305,7 @@ unsafe fn sort8_stable<T: crate::Freeze, F: FnMut(&T, &T) -> bool>(

// SAFETY: TODO
unsafe {
bi_directional_merge_even(&*ptr::slice_from_raw_parts(scratch_base, 8), dst, is_less);
bidirectional_merge(&*ptr::slice_from_raw_parts(scratch_base, 8), dst, is_less);
}
}

Expand All @@ -326,10 +321,10 @@ unsafe fn merge_up<T, F: FnMut(&T, &T) -> bool>(
//
// if !is_less(&*right_src, &*left_src) {
// ptr::copy_nonoverlapping(left_src, dst, 1);
// left_src = left_src.wrapping_add(1);
// left_src = left_src.add(1);
// } else {
// ptr::copy_nonoverlapping(right_src, dst, 1);
// right_src = right_src.wrapping_add(1);
// right_src = right_src.add(1);
// }
// dst = dst.add(1);

Expand All @@ -339,8 +334,8 @@ unsafe fn merge_up<T, F: FnMut(&T, &T) -> bool>(
let is_l = !is_less(&*right_src, &*left_src);
let src = if is_l { left_src } else { right_src };
ptr::copy_nonoverlapping(src, dst, 1);
right_src = right_src.wrapping_add(!is_l as usize);
left_src = left_src.wrapping_add(is_l as usize);
right_src = right_src.add(!is_l as usize);
left_src = left_src.add(is_l as usize);
dst = dst.add(1);
}

Expand Down Expand Up @@ -380,15 +375,15 @@ unsafe fn merge_down<T, F: FnMut(&T, &T) -> bool>(
(left_src, right_src, dst)
}

/// Merge v assuming the len is even and v[..len / 2] and v[len / 2..] are sorted.
/// Merge v assuming v[..len / 2] and v[len / 2..] are sorted.
///
/// Original idea for bi-directional merging by Igor van den Hoven (quadsort), adapted to only use
/// merge up and down. In contrast to the original parity_merge function, it performs 2 writes
/// instead of 4 per iteration. Ord violation detection was added.
///
// SAFETY: the caller must guarantee that `dst` is valid for v.len() writes.
// Also `v.as_ptr` and `dst` must not alias.
unsafe fn bi_directional_merge_even<T: crate::Freeze, F: FnMut(&T, &T) -> bool>(
// Also `v.as_ptr` and `dst` must not alias. v.len() must be >= 2.
unsafe fn bidirectional_merge<T: crate::Freeze, F: FnMut(&T, &T) -> bool>(
v: &[T],
dst: *mut T,
is_less: &mut F,
Expand Down Expand Up @@ -435,22 +430,34 @@ unsafe fn bi_directional_merge_even<T: crate::Freeze, F: FnMut(&T, &T) -> bool>(
// bounds if the caller guarantees that `dst` is valid for `v.len()` writes.
unsafe {
let mut left = src;
let mut right = src.wrapping_add(len_div_2);
let mut right = src.add(len_div_2);
let mut dst = dst;

let mut left_rev = src.wrapping_add(len_div_2 - 1);
let mut right_rev = src.wrapping_add(len - 1);
let mut dst_rev = dst.wrapping_add(len - 1);
let mut left_rev = src.add(len_div_2 - 1);
let mut right_rev = src.add(len - 1);
let mut dst_rev = dst.add(len - 1);

for _ in 0..len_div_2 {
(left, right, dst) = merge_up(left, right, dst, is_less);
(left_rev, right_rev, dst_rev) = merge_down(left_rev, right_rev, dst_rev, is_less);
}

let left_diff = (left as usize).wrapping_sub(left_rev as usize);
let right_diff = (right as usize).wrapping_sub(right_rev as usize);
let left_end = left_rev.wrapping_add(1);
let right_end = right_rev.wrapping_add(1);

// Odd length, so one element is left unconsumed in the input.
if len % 2 != 0 {
let left_nonempty = left < left_end;
let last_src = if left_nonempty { left } else { right };
ptr::copy_nonoverlapping(last_src, dst, 1);
left = left.add(left_nonempty as usize);
right = right.add((!left_nonempty) as usize);
}

if !(left_diff == mem::size_of::<T>() && right_diff == mem::size_of::<T>()) {
// We now should have consumed the full input exactly once. This can
// only fail if the comparison operator fails to be Ord, in which case
// we will panic and never access the inconsistent state in dst.
if left != left_end || right != right_end {
panic_on_ord_violation();
}
}
Expand Down

0 comments on commit 9c2d26b

Please sign in to comment.