Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change to simpler merge algorithm. #22

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ where
// - alloc len elements up to MAX_FULL_ALLOC_BYTES
// - alloc len / 2 elements
// This allows us to use the most performant algorithms for small-medium
// sized inputs while scaling down to len / 2 for larger inputs. We need at
// least len / 2 for our stable merging routine.
// sized inputs while scaling down to len / 2 for larger inputs.
const MAX_FULL_ALLOC_BYTES: usize = 8_000_000;
let len = v.len();
let half = len / 2 + 1; // Add one such that for odd sizes either half fits.
let full_alloc_size = cmp::min(len, MAX_FULL_ALLOC_BYTES / mem::size_of::<T>());
let alloc_size = cmp::max(len / 2, full_alloc_size);
let alloc_size = cmp::max(half, full_alloc_size);

let mut buf = BufT::with_capacity(alloc_size);
let scratch_slice =
Expand Down
219 changes: 62 additions & 157 deletions src/merge.rs
Original file line number Diff line number Diff line change
@@ -1,172 +1,77 @@
use core::cmp;
use core::intrinsics;
use core::mem::MaybeUninit;
use core::ptr;

/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
/// stores the result into `v[..]`.
///
/// # Safety
///
/// The two slices must be non-empty and `mid` must be in bounds. Buffer `buf` must be long enough
/// to hold a copy of the shorter slice. Also, `T` must not be a zero-sized type.
///
/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no
/// performance impact.
#[inline(never)]
unsafe fn merge_fallback<T, F>(v: &mut [T], mid: usize, buf: *mut T, is_less: &mut F)
/// stores the result into `v[..]`. Does O(v.len()) comparisons and
/// O(v.len() * (1 + v.len() / scratch.len())) moves.
#[inline(always)]
pub fn merge<T, F>(v: &mut [T], scratch: &mut [MaybeUninit<T>], mid: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();
let v = v.as_mut_ptr();

// SAFETY: mid and len must be in-bounds of v.
let (v_mid, v_end) = unsafe { (v.add(mid), v.add(len)) };

// The merge process first copies the shorter run into `buf`. Then it traces the newly copied
// run and the longer run forwards (or backwards), comparing their next unconsumed elements and
// copying the lesser (or greater) one into `v`.
//
// As soon as the shorter run is fully consumed, the process is done. If the longer run gets
// consumed first, then we must copy whatever is left of the shorter run into the remaining
// hole in `v`.
//
// Intermediate state of the process is always tracked by `hole`, which serves two purposes:
// 1. Protects integrity of `v` from panics in `is_less`.
// 2. Fills the remaining hole in `v` if the longer run gets consumed first.
//
// Panic safety:
//
// If `is_less` panics at any point during the process, `hole` will get dropped and fill the
// hole in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every
// object it initially held exactly once.
let mut hole;

if mid <= len - mid {
// The left run is shorter.

// SAFETY: buf must have enough capacity for `v[..mid]`.
unsafe {
ptr::copy_nonoverlapping(v, buf, mid);
hole = MergeHole {
start: buf,
end: buf.add(mid),
dst: v,
};
}

// Initially, these pointers point to the beginnings of their arrays.
let left = &mut hole.start;
let mut right = v_mid;
let out = &mut hole.dst;

while *left < hole.end && right < v_end {
// Consume the lesser side.
// If equal, prefer the left run to maintain stability.

// SAFETY: left and right must be valid and part of v same for out.
unsafe {
let to_copy = if is_less(&*right, &**left) {
get_and_increment(&mut right)
} else {
get_and_increment(left)
};
ptr::copy_nonoverlapping(to_copy, get_and_increment(out), 1);
}
}
} else {
// The right run is shorter.

// SAFETY: buf must have enough capacity for `v[mid..]`.
unsafe {
ptr::copy_nonoverlapping(v_mid, buf, len - mid);
hole = MergeHole {
start: buf,
end: buf.add(len - mid),
dst: v_mid,
};
}

// Initially, these pointers point past the ends of their arrays.
let left = &mut hole.dst;
let right = &mut hole.end;
let mut out = v_end;
let v_base = v.as_mut_ptr();
let scratch_len = scratch.len();
let scratch_base = MaybeUninit::slice_as_mut_ptr(scratch);

while v < *left && buf < *right {
// Consume the greater side.
// If equal, prefer the right run to maintain stability.

// SAFETY: left and right must be valid and part of v same for out.
unsafe {
let to_copy = if is_less(&*right.sub(1), &*left.sub(1)) {
decrement_and_get(left)
} else {
decrement_and_get(right)
};
ptr::copy_nonoverlapping(to_copy, decrement_and_get(&mut out), 1);
unsafe {
// SAFETY
// The scratch and element array respectively have the following layouts:
//
// | merged elements | free space |
// ^ scratch_base ^ scratch_out ^ scratch_end
//
// | merged elements | gap | unmerged left | gap | unmerged right |
// ^ v_base ^ merged_out ^ left ^ left_end ^ right ^ v_end
//
// Note that the 'gaps' here are purely logical, not physical. We
// strictly copy from the element array to the scratch, and leave the
// input array completely untouched, should a panic occur. Only when we
// are done or the scratch buffer is full do we copy back the merged
// elements into the source array, closing the gaps. This is a panicless
// procedure, and thus safe. We never call the comparison operator again
// on any element that was copied, so interior mutability is not a problem.
let scratch_end = scratch_base.add(scratch_len);
let v_end = v_base.add(len);

let mut left = v_base;
let mut left_end = left.add(mid);
let mut right = left_end;
let mut scratch_out = scratch_base;
let mut merged_out = v_base;
let mut merge_done = false;

while !merge_done {
// Fill the scratch space with merged elements.
let free_scratch_space = scratch_end.sub_ptr(scratch_out);
let left_len = left_end.sub_ptr(left);
let right_len = v_end.sub_ptr(right);
let safe_iters = free_scratch_space.min(left_len).min(right_len);
for _ in 0..safe_iters {
let right_less = is_less(&*right, &*left);
let src = if right_less { right } else { left };
ptr::copy_nonoverlapping(src, scratch_out, 1);

scratch_out = scratch_out.add(1);
left = left.add((!right_less) as usize);
right = right.add(right_less as usize);
}
}
}
// Finally, `hole` gets dropped. If the shorter run was not fully consumed, whatever remains of
// it will now be copied into the hole in `v`.

unsafe fn get_and_increment<T>(ptr: &mut *mut T) -> *mut T {
let old = *ptr;

// SAFETY: ptr.add(1) must still be a valid pointer and part of `v`.
*ptr = unsafe { ptr.add(1) };
old
}

unsafe fn decrement_and_get<T>(ptr: &mut *mut T) -> *mut T {
// SAFETY: ptr.sub(1) must still be a valid pointer and part of `v`.
*ptr = unsafe { ptr.sub(1) };
*ptr
}

// When dropped, copies the range `start..end` into `dst..`.
struct MergeHole<T> {
start: *mut T,
end: *mut T,
dst: *mut T,
}

impl<T> Drop for MergeHole<T> {
fn drop(&mut self) {
// SAFETY: `T` is not a zero-sized type, and these are pointers into a slice's elements.
unsafe {
let len = self.end.sub_ptr(self.start);
ptr::copy_nonoverlapping(self.start, self.dst, len);
merge_done = left == left_end || right == v_end;
if scratch_out == scratch_end || merge_done {
// Move the remaining left elements next to the right elements.
let new_left_len = left_end.sub_ptr(left);
let new_left = right.sub(new_left_len);
ptr::copy(left, new_left, new_left_len);
left = new_left;
left_end = left.add(new_left_len);

// Move merged elements in scratch back to v and reset scratch.
let merged_n = scratch_out.sub_ptr(scratch_base);
ptr::copy_nonoverlapping(scratch_base, merged_out, merged_n);
merged_out = merged_out.add(merged_n);
scratch_out = scratch_base;
}
}
}
}

/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
/// stores the result into `v[..]`.
///
/// # Safety
///
/// Buffer as pointed to by `buffer` must have space for `buf_len` writes. And must not alias `v`.
#[inline(always)]
pub fn merge<T, F>(v: &mut [T], scratch: &mut [MaybeUninit<T>], mid: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();

if mid == 0 || mid >= len || scratch.len() < cmp::min(mid, len - mid) {
intrinsics::abort();
}

// SAFETY: We checked that the two slices must be non-empty and `mid` must be in bounds. The
// caller has to guarantee that Buffer `buf` must be long enough to hold a copy of the shorter
// slice. Also, `T` must not be a zero-sized type. We checked that T is observation safe. Should
// is_less panic v was not modified in bi_directional_merge and retains it's original input.
// buffer and v must not alias and swap has v.len() space.
unsafe {
let buffer = MaybeUninit::slice_as_mut_ptr(scratch);
merge_fallback(v, mid, buffer, is_less);
}
}
Loading