Apply review comments
Voultapher committed Jul 29, 2023
1 parent d5f664d commit a2bc908
Showing 2 changed files with 79 additions and 79 deletions.
14 changes: 7 additions & 7 deletions src/merge.rs
@@ -51,14 +51,14 @@ where
hole = MergeHole {
start: buf,
end: buf.add(mid),
dest: v,
dst: v,
};
}

// Initially, these pointers point to the beginnings of their arrays.
let left = &mut hole.start;
let mut right = v_mid;
let out = &mut hole.dest;
let out = &mut hole.dst;

while *left < hole.end && right < v_end {
// Consume the lesser side.
@@ -83,12 +83,12 @@ where
hole = MergeHole {
start: buf,
end: buf.add(len - mid),
dest: v_mid,
dst: v_mid,
};
}

// Initially, these pointers point past the ends of their arrays.
let left = &mut hole.dest;
let left = &mut hole.dst;
let right = &mut hole.end;
let mut out = v_end;

@@ -124,19 +124,19 @@ where
*ptr
}

// When dropped, copies the range `start..end` into `dest..`.
// When dropped, copies the range `start..end` into `dst..`.
struct MergeHole<T> {
start: *mut T,
end: *mut T,
dest: *mut T,
dst: *mut T,
}

impl<T> Drop for MergeHole<T> {
fn drop(&mut self) {
// SAFETY: `T` is not a zero-sized type, and these are pointers into a slice's elements.
unsafe {
let len = self.end.sub_ptr(self.start);
ptr::copy_nonoverlapping(self.start, self.dest, len);
ptr::copy_nonoverlapping(self.start, self.dst, len);
}
}
}
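
The `dest` to `dst` rename above touches `MergeHole`, the drop guard that keeps the merge routine panic-safe: if the comparison closure panics mid-merge, dropping the hole copies the not-yet-merged range back into the destination, so the slice still holds every element exactly once. Below is a minimal standalone sketch of that pattern; the name `CopyOnDrop` and the arrays in `main` are illustrative and not part of this commit, and it uses the stable `offset_from` where the code above uses `sub_ptr`.

use std::ptr;

// Illustrative guard in the spirit of `MergeHole`: when dropped, it copies
// whatever remains in `start..end` into `dst..`, whether the drop happens
// normally or during unwinding.
struct CopyOnDrop<T> {
    start: *mut T,
    end: *mut T,
    dst: *mut T,
}

impl<T> Drop for CopyOnDrop<T> {
    fn drop(&mut self) {
        // SAFETY: the constructor must guarantee that `start..end` are
        // initialized elements of one allocation, `T` is not zero-sized, and
        // `dst` is valid for that many non-overlapping writes.
        unsafe {
            let len = self.end.offset_from(self.start) as usize;
            ptr::copy_nonoverlapping(self.start, self.dst, len);
        }
    }
}

fn main() {
    let mut buf = [1, 2, 3];
    let mut out = [0; 3];
    {
        let _hole = CopyOnDrop {
            start: buf.as_mut_ptr(),
            end: unsafe { buf.as_mut_ptr().add(3) },
            dst: out.as_mut_ptr(),
        };
        // `_hole` is dropped at the end of this scope and copies the three
        // elements into `out`; a panic inside the scope would do the same.
    }
    assert_eq!(out, [1, 2, 3]);
}
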
144 changes: 72 additions & 72 deletions src/smallsort.rs
@@ -17,15 +17,15 @@ where
T::small_sort(v, is_less);
}

struct GapGuardNonoverlapping<T> {
struct GapGuard<T> {
pos: *mut T,
value: ManuallyDrop<T>,
}

impl<T> Drop for GapGuardNonoverlapping<T> {
impl<T> Drop for GapGuard<T> {
fn drop(&mut self) {
unsafe {
ptr::write(self.pos, ManuallyDrop::take(&mut self.value));
ptr::copy_nonoverlapping(&*self.value, self.pos, 1);
}
}
}
@@ -67,7 +67,7 @@ where
// If `is_less` panics at any point during the process, `gap` will get dropped and
// fill the gap in `v` with `tmp`, thus ensuring that `v` still holds every object it
// initially held exactly once.
let mut gap = GapGuardNonoverlapping {
let mut gap = GapGuard {
pos: v_i.sub(1),
value: mem::ManuallyDrop::new(ptr::read(v_i)),
};
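
For context on the renamed `GapGuard` (formerly `GapGuardNonoverlapping`) used in this hunk: the guard owns the element that was read out of the slice, and if `is_less` panics while larger elements are shifted over the gap, its `Drop` impl writes that element back, so the slice never loses or duplicates a value. Here is a self-contained sketch of one such insertion step, under the assumption that it mirrors the surrounding insertion routine; `Gap`, `insert_tail`, and the values in `main` are illustrative names, not code from this commit.

use std::mem::ManuallyDrop;
use std::ptr;

// Illustrative guard in the spirit of `GapGuard`: it holds one element that
// was read out of the slice and, on drop, writes it into the current gap.
struct Gap<T> {
    pos: *mut T,
    value: ManuallyDrop<T>,
}

impl<T> Drop for Gap<T> {
    fn drop(&mut self) {
        // SAFETY: `pos` always points at the one-element gap in the slice and
        // `value` is the element that was read out of it.
        unsafe {
            ptr::copy_nonoverlapping(&*self.value, self.pos, 1);
        }
    }
}

// Insert the last element of `v` into the already sorted `v[..v.len() - 1]`.
fn insert_tail<T: Ord>(v: &mut [T]) {
    if v.len() < 2 {
        return;
    }
    unsafe {
        let base = v.as_mut_ptr();
        let i = v.len() - 1;
        // Read the element out; `v[i]` is now a logical gap that the guard
        // will fill again, either normally or during unwinding.
        let mut gap = Gap {
            pos: base.add(i),
            value: ManuallyDrop::new(ptr::read(base.add(i))),
        };
        let mut j = i;
        while j > 0 && *gap.value < *base.add(j - 1) {
            // Shift the larger element right into the gap; the gap moves left.
            ptr::copy_nonoverlapping(base.add(j - 1), base.add(j), 1);
            j -= 1;
            gap.pos = base.add(j);
        }
        // `gap` is dropped here and writes the held element into place.
    }
}

fn main() {
    let mut v = vec![1, 3, 5, 7, 4];
    insert_tail(&mut v);
    assert_eq!(v, [1, 3, 4, 5, 7]);
}
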
Expand Down Expand Up @@ -187,7 +187,7 @@ impl<T: crate::Freeze> SmallSortTypeImpl for T {

/// SAFETY: The caller MUST guarantee that `v_base` is valid for 4 reads and `dest_ptr` is valid
/// for 4 writes.
pub unsafe fn sort4_stable<T, F>(v_base: *const T, out: *mut T, is_less: &mut F)
pub unsafe fn sort4_stable<T, F>(v_base: *const T, dst: *mut T, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
@@ -227,10 +227,10 @@ where
let lo = select(c5, unknown_right, unknown_left);
let hi = select(c5, unknown_left, unknown_right);

ptr::copy_nonoverlapping(min, out, 1);
ptr::copy_nonoverlapping(lo, out.add(1), 1);
ptr::copy_nonoverlapping(hi, out.add(2), 1);
ptr::copy_nonoverlapping(max, out.add(3), 1);
ptr::copy_nonoverlapping(min, dst, 1);
ptr::copy_nonoverlapping(lo, dst.add(1), 1);
ptr::copy_nonoverlapping(hi, dst.add(2), 1);
ptr::copy_nonoverlapping(max, dst.add(3), 1);
}
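
This hunk only shows the tail of `sort4_stable` (the fifth comparison plus the four writes, with `out` renamed to `dst`). Assuming the elided part follows the min/max/unknown scheme that the visible names suggest, here is a safe, index-based re-expression of the idea; `sort4_copy` and the use of `Ord` instead of the `is_less` closure are illustrative choices, not code from this commit.

// Stable sort of four Copy elements with five comparisons: order each pair,
// pick the overall min and max from the pair results, then order the two
// remaining "unknown" elements with a fifth comparison.
fn sort4_copy<T: Copy + Ord>(v: [T; 4]) -> [T; 4] {
    let c1 = v[1] < v[0];
    let c2 = v[3] < v[2];
    let a = c1 as usize; // index of the smaller of v[0], v[1]
    let b = !c1 as usize; // index of the larger of v[0], v[1]
    let c = 2 + c2 as usize; // index of the smaller of v[2], v[3]
    let d = 2 + !c2 as usize; // index of the larger of v[2], v[3]

    // The smaller of the two pair-minima is the overall min, the larger of the
    // two pair-maxima is the overall max; the leftovers keep their original
    // left/right order so the result stays stable.
    let c3 = v[c] < v[a];
    let c4 = v[d] < v[b];
    let min = if c3 { c } else { a };
    let max = if c4 { b } else { d };
    let unknown_left = if c3 { a } else if c4 { c } else { b };
    let unknown_right = if c4 { d } else if c3 { b } else { c };

    // Fifth comparison orders the two middle elements.
    let c5 = v[unknown_right] < v[unknown_left];
    let lo = if c5 { unknown_right } else { unknown_left };
    let hi = if c5 { unknown_left } else { unknown_right };

    [v[min], v[lo], v[hi], v[max]]
}

fn main() {
    assert_eq!(sort4_copy([3, 1, 2, 1]), [1, 1, 2, 3]);
    assert_eq!(sort4_copy([4, 3, 2, 1]), [1, 2, 3, 4]);
}
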

#[inline(always)]
@@ -266,7 +266,7 @@ where
// into v only if there was a panic. This technique is also known as ping-pong merge.
let drop_guard = DropGuard {
src: scratch_base,
dest: v_base,
dst: v_base,
};
bi_directional_merge_even(
&*ptr::slice_from_raw_parts(scratch_base, 8),
@@ -278,25 +278,25 @@ where

struct DropGuard<T> {
src: *const T,
dest: *mut T,
dst: *mut T,
}

impl<T> Drop for DropGuard<T> {
fn drop(&mut self) {
// SAFETY: `T` is not a zero-sized type, src must hold the original 8 elements of v in
// any order. And dest must be valid to write 8 elements.
// any order. And dst must be valid to write 8 elements.
unsafe {
ptr::copy_nonoverlapping(self.src, self.dest, 8);
ptr::copy_nonoverlapping(self.src, self.dst, 8);
}
}
}
}

#[inline(always)]
unsafe fn merge_up<T, F>(
mut src_left: *const T,
mut src_right: *const T,
mut out: *mut T,
mut left_src: *const T,
mut right_src: *const T,
mut dst: *mut T,
is_less: &mut F,
) -> (*const T, *const T, *mut T)
where
@@ -305,34 +305,34 @@ where
// This is a branchless merge utility function.
// The equivalent code with a branch would be:
//
// if !is_less(&*src_right, &*src_left) {
// ptr::copy_nonoverlapping(src_left, out, 1);
// src_left = src_left.wrapping_add(1);
// if !is_less(&*right_src, &*left_src) {
// ptr::copy_nonoverlapping(left_src, dst, 1);
// left_src = left_src.wrapping_add(1);
// } else {
// ptr::copy_nonoverlapping(src_right, out, 1);
// src_right = src_right.wrapping_add(1);
// ptr::copy_nonoverlapping(right_src, dst, 1);
// right_src = right_src.wrapping_add(1);
// }
// out = out.add(1);
// dst = dst.add(1);

// SAFETY: The caller must guarantee that `src_left`, `src_right` are valid to read and
// `out` is valid to write, while not aliasing.
// SAFETY: The caller must guarantee that `left_src`, `right_src` are valid to read and
// `dst` is valid to write, while not aliasing.
unsafe {
let is_l = !is_less(&*src_right, &*src_left);
let dest = if is_l { src_left } else { src_right };
ptr::copy_nonoverlapping(dest, out, 1);
src_right = src_right.wrapping_add(!is_l as usize);
src_left = src_left.wrapping_add(is_l as usize);
out = out.add(1);
let is_l = !is_less(&*right_src, &*left_src);
let src = if is_l { left_src } else { right_src };
ptr::copy_nonoverlapping(src, dst, 1);
right_src = right_src.wrapping_add(!is_l as usize);
left_src = left_src.wrapping_add(is_l as usize);
dst = dst.add(1);
}

(src_left, src_right, out)
(left_src, right_src, dst)
}
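
The `left_src`/`right_src`/`dst` renames above do not change the branchless structure of `merge_up`: each iteration copies exactly one element and then advances the two source cursors by `is_l as usize` and `!is_l as usize`. Below is a safe sketch of that idea on slices, covering only the forward (`merge_up`) direction; the names and the `Ord` bound are illustrative, not code from this commit.

// Merge two sorted slices into `dst`, advancing the indices with the result of
// the comparison instead of branching around the cursor bumps.
fn branchless_merge<T: Copy + Ord>(left: &[T], right: &[T], dst: &mut Vec<T>) {
    let (mut i, mut j) = (0, 0);
    while i < left.len() && j < right.len() {
        // `is_l` plays the role of `!is_less(right, left)`: take the left
        // element on ties so the merge stays stable.
        let is_l = !(right[j] < left[i]);
        dst.push(if is_l { left[i] } else { right[j] });
        i += is_l as usize;
        j += !is_l as usize;
    }
    dst.extend_from_slice(&left[i..]);
    dst.extend_from_slice(&right[j..]);
}

fn main() {
    let mut out = Vec::new();
    branchless_merge(&[1, 3, 5, 7], &[2, 2, 6, 8], &mut out);
    assert_eq!(out, [1, 2, 2, 3, 5, 6, 7, 8]);
}
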

#[inline(always)]
unsafe fn merge_down<T, F>(
mut src_left: *const T,
mut src_right: *const T,
mut out: *mut T,
mut left_src: *const T,
mut right_src: *const T,
mut dst: *mut T,
is_less: &mut F,
) -> (*const T, *const T, *mut T)
where
@@ -341,97 +341,97 @@ where
// This is a branchless merge utility function.
// The equivalent code with a branch would be:
//
// if !is_less(&*src_right, &*src_left) {
// ptr::copy_nonoverlapping(src_right, out, 1);
// src_right = src_right.wrapping_sub(1);
// if !is_less(&*right_src, &*left_src) {
// ptr::copy_nonoverlapping(right_src, dst, 1);
// right_src = right_src.wrapping_sub(1);
// } else {
// ptr::copy_nonoverlapping(src_left, out, 1);
// src_left = src_left.wrapping_sub(1);
// ptr::copy_nonoverlapping(left_src, dst, 1);
// left_src = left_src.wrapping_sub(1);
// }
// out = out.sub(1);
// dst = dst.sub(1);

// SAFETY: The caller must guarantee that `src_left`, `src_right` are valid to read and
// `out` is valid to write, while not aliasing.
// SAFETY: The caller must guarantee that `left_src`, `right_src` are valid to read and
// `dst` is valid to write, while not aliasing.
unsafe {
let is_l = !is_less(&*src_right, &*src_left);
let dest = if is_l { src_right } else { src_left };
ptr::copy_nonoverlapping(dest, out, 1);
src_right = src_right.wrapping_sub(is_l as usize);
src_left = src_left.wrapping_sub(!is_l as usize);
out = out.sub(1);
let is_l = !is_less(&*right_src, &*left_src);
let src = if is_l { right_src } else { left_src };
ptr::copy_nonoverlapping(src, dst, 1);
right_src = right_src.wrapping_sub(is_l as usize);
left_src = left_src.wrapping_sub(!is_l as usize);
dst = dst.sub(1);
}

(src_left, src_right, out)
(left_src, right_src, dst)
}

/// Merge v assuming the len is even and v[..len / 2] and v[len / 2..] are sorted.
///
/// Original idea for bi-directional merging by Igor van den Hoven (quadsort), adapted to only use
/// merge up and down. In contrast to the original parity_merge function, it performs 2 writes
/// instead of 4 per iteration. Ord violation detection was added.
unsafe fn bi_directional_merge_even<T, F>(v: &[T], out: *mut T, is_less: &mut F)
unsafe fn bi_directional_merge_even<T, F>(v: &[T], dst: *mut T, is_less: &mut F)
where
T: crate::Freeze,
F: FnMut(&T, &T) -> bool,
{
// SAFETY: the caller must guarantee that `out` is valid for v.len() writes.
// Also `v.as_ptr` and `out` must not alias.
// SAFETY: the caller must guarantee that `dst` is valid for v.len() writes.
// Also `v.as_ptr` and `dst` must not alias.
//
// The caller must guarantee that T cannot modify itself inside is_less.
// merge_up and merge_down read left and right pointers and potentially modify the stack value
// they point to, if T has interior mutability. This may leave one or two potential writes to
// the stack value un-observed when dest is copied onto of src.
// the stack value un-observed when dst is copied onto of src.

// It helps to visualize the merge:
//
// Initial:
//
// |out (in dest)
// |left |right
// |dst (in dst)
// |left |right
// v v
// [xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx]
// ^ ^
// |rev_left |rev_right
// |rev_out (in dest)
// |left_rev |right_rev
// |dst_rev (in dst)
//
// After:
//
// |out (in dest)
// |left | |right
// |dst (in dst)
// |left | |right
// v v v
// [xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx]
// ^ ^ ^
// |rev_left | |rev_right
// |rev_out (in dest)
// |left_rev | |right_rev
// |dst_rev (in dst)
//
//
// Note, the pointers that have been written, are now one past where they were read and
// copied. written == incremented or decremented + copy to dest.
// copied. written == incremented or decremented + copy to dst.

let len = v.len();
let src = v.as_ptr();

let len_div_2 = len / 2;

// SAFETY: No matter what the result of the user-provided comparison function is, all 4 read
// pointers will always be in-bounds. Writing `out` and `rev_out` will always be in
// bounds if the caller guarantees that `out` is valid for `v.len()` writes.
// pointers will always be in-bounds. Writing `dst` and `dst_rev` will always be in
// bounds if the caller guarantees that `dst` is valid for `v.len()` writes.
unsafe {
let mut left = src;
let mut right = src.wrapping_add(len_div_2);
let mut out = out;
let mut dst = dst;

let mut rev_left = src.wrapping_add(len_div_2 - 1);
let mut rev_right = src.wrapping_add(len - 1);
let mut rev_out = out.wrapping_add(len - 1);
let mut left_rev = src.wrapping_add(len_div_2 - 1);
let mut right_rev = src.wrapping_add(len - 1);
let mut dst_rev = dst.wrapping_add(len - 1);

for _ in 0..len_div_2 {
(left, right, out) = merge_up(left, right, out, is_less);
(rev_left, rev_right, rev_out) = merge_down(rev_left, rev_right, rev_out, is_less);
(left, right, dst) = merge_up(left, right, dst, is_less);
(left_rev, right_rev, dst_rev) = merge_down(left_rev, right_rev, dst_rev, is_less);
}

let left_diff = (left as usize).wrapping_sub(rev_left as usize);
let right_diff = (right as usize).wrapping_sub(rev_right as usize);
let left_diff = (left as usize).wrapping_sub(left_rev as usize);
let right_diff = (right as usize).wrapping_sub(right_rev as usize);

if !(left_diff == mem::size_of::<T>() && right_diff == mem::size_of::<T>()) {
panic_on_ord_violation();