Skip to content

Commit

Permalink
Refine merge implementation
Browse files Browse the repository at this point in the history
- Use reliable branchless logic for both sides
- Single memcpy instead of one per branch
- Improve merging down code-gen on x86 and Arm
  • Loading branch information
Voultapher committed Aug 2, 2023
1 parent 14c7b7a commit ba52356
Showing 1 changed file with 110 additions and 131 deletions.
241 changes: 110 additions & 131 deletions src/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,134 +5,141 @@ use core::ptr;

/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
/// stores the result into `v[..]`.
///
/// # Safety
///
/// The two slices must be non-empty and `mid` must be in bounds. Buffer `buf` must be long enough
/// to hold a copy of the shorter slice. Also, `T` must not be a zero-sized type.
///
/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no
/// performance impact.
#[inline(never)]
unsafe fn merge_fallback<T, F>(v: &mut [T], mid: usize, buf: *mut T, is_less: &mut F)
pub fn merge<T, F>(v: &mut [T], scratch: &mut [MaybeUninit<T>], mid: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();
let v = v.as_mut_ptr();

// SAFETY: mid and len must be in-bounds of v.
let (v_mid, v_end) = unsafe { (v.add(mid), v.add(len)) };

// The merge process first copies the shorter run into `buf`. Then it traces the newly copied
// run and the longer run forwards (or backwards), comparing their next unconsumed elements and
// copying the lesser (or greater) one into `v`.
//
// As soon as the shorter run is fully consumed, the process is done. If the longer run gets
// consumed first, then we must copy whatever is left of the shorter run into the remaining
// hole in `v`.
//
// Intermediate state of the process is always tracked by `hole`, which serves two purposes:
// 1. Protects integrity of `v` from panics in `is_less`.
// 2. Fills the remaining hole in `v` if the longer run gets consumed first.
//
// Panic safety:
//
// If `is_less` panics at any point during the process, `hole` will get dropped and fill the
// hole in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every
// object it initially held exactly once.
let mut hole;

if mid <= len - mid {
// The left run is shorter.

// SAFETY: buf must have enough capacity for `v[..mid]`.
unsafe {
ptr::copy_nonoverlapping(v, buf, mid);
hole = MergeHole {

if mid == 0 || mid >= len || scratch.len() < cmp::min(mid, len - mid) {
intrinsics::abort();
}

// SAFETY: We checked that the two slices are non-empty and that `mid` is in bounds. The
// caller guarantees that the buffer `buf` is long enough to hold a copy of the shorter
// slice. Also, `T` must not be a zero-sized type. Should `is_less` panic, `v` retains its
// original set of elements. The buffer and `v` must not alias, and the scratch space holds
// at least the shorter run's length.
unsafe {
// The merge process first copies the shorter run into `buf`. Then it traces the newly
// copied run and the longer run forwards (or backwards), comparing their next unconsumed
// elements and copying the lesser (or greater) one into `v`.
//
// As soon as the shorter run is fully consumed, the process is done. If the longer run gets
// consumed first, then we must copy whatever is left of the shorter run into the remaining
// gap in `v`.
//
// Intermediate state of the process is always tracked by `gap`, which serves two purposes:
// 1. Protects integrity of `v` from panics in `is_less`.
// 2. Fills the remaining gap in `v` if the longer run gets consumed first.
//
// Panic safety:
//
// If `is_less` panics at any point during the process, `gap` will get dropped and fill the
// gap in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every
// object it initially held exactly once.

let buf = MaybeUninit::slice_as_mut_ptr(scratch);

let v_base = v.as_mut_ptr();
let v_mid = v_base.add(mid);
let v_end = v_base.add(len);

let left_len = mid;
let right_len = len - mid;

let left_is_shorter = left_len <= right_len;

let save_base = if left_is_shorter { v_base } else { v_mid };
let save_len = if left_is_shorter { left_len } else { right_len };

ptr::copy_nonoverlapping(save_base, buf, save_len);

let mut merge_state;

if left_is_shorter {
merge_state = MergeState {
start: buf,
end: buf.add(mid),
dst: v,
dst: v_base,
};
}

// Initially, these pointers point to the beginnings of their arrays.
let left = &mut hole.start;
let mut right = v_mid;
let out = &mut hole.dst;

while *left < hole.end && right < v_end {
// Consume the lesser side.
// If equal, prefer the left run to maintain stability.

// SAFETY: left and right must be valid and part of v same for out.
unsafe {
let to_copy = if is_less(&*right, &**left) {
get_and_increment(&mut right)
} else {
get_and_increment(left)
};
ptr::copy_nonoverlapping(to_copy, get_and_increment(out), 1);
}
}
} else {
// The right run is shorter.

// SAFETY: buf must have enough capacity for `v[mid..]`.
unsafe {
ptr::copy_nonoverlapping(v_mid, buf, len - mid);
hole = MergeHole {
merge_state.merge_up(v_mid, v_end, is_less);
} else {
merge_state = MergeState {
start: buf,
end: buf.add(len - mid),
dst: v_mid,
};

merge_state.merge_down(v_base, buf, v_end, is_less);
}
// Finally, `merge_state` gets dropped. If the shorter run was not fully consumed, whatever
// remains of it will now be copied into the hole in `v`.
}

// Initially, these pointers point past the ends of their arrays.
let left = &mut hole.dst;
let right = &mut hole.end;
let mut out = v_end;
// When dropped, copies the range `start..end` into `dst..`.
struct MergeState<T> {
start: *mut T,
end: *mut T,
dst: *mut T,
}

while v < *left && buf < *right {
// Consume the greater side.
// If equal, prefer the right run to maintain stability.
impl<T> MergeState<T> {
unsafe fn merge_up<F: FnMut(&T, &T) -> bool>(
&mut self,
mut right: *mut T,
right_end: *const T,
is_less: &mut F,
) {
// left == self.start
// out == self.dst

// SAFETY: left and right must be valid and part of v same for out.
unsafe {
let to_copy = if is_less(&*right.sub(1), &*left.sub(1)) {
decrement_and_get(left)
} else {
decrement_and_get(right)
};
ptr::copy_nonoverlapping(to_copy, decrement_and_get(&mut out), 1);
while self.start != self.end && right as *const T != right_end {
let consume_left = !is_less(&*right, &*self.start);

let src = if consume_left { self.start } else { right };
ptr::copy_nonoverlapping(src, self.dst, 1);

self.start = self.start.add(consume_left as usize);
right = right.add(!consume_left as usize);

self.dst = self.dst.add(1);
}
}
}
// Finally, `hole` gets dropped. If the shorter run was not fully consumed, whatever remains of
// it will now be copied into the hole in `v`.

unsafe fn get_and_increment<T>(ptr: &mut *mut T) -> *mut T {
let old = *ptr;
unsafe fn merge_down<F: FnMut(&T, &T) -> bool>(
&mut self,
left_end: *const T,
right_end: *const T,
mut out: *mut T,
is_less: &mut F,
) {
// left == self.dst;
// right == self.end;

// SAFETY: ptr.add(1) must still be a valid pointer and part of `v`.
*ptr = unsafe { ptr.add(1) };
old
}
loop {
let left = self.dst.sub(1);
let right = self.end.sub(1);
out = out.sub(1);

unsafe fn decrement_and_get<T>(ptr: &mut *mut T) -> *mut T {
// SAFETY: ptr.sub(1) must still be a valid pointer and part of `v`.
*ptr = unsafe { ptr.sub(1) };
*ptr
}
let consume_left = is_less(&*right, &*left);

// When dropped, copies the range `start..end` into `dst..`.
struct MergeHole<T> {
start: *mut T,
end: *mut T,
dst: *mut T,
let src = if consume_left { left } else { right };
ptr::copy_nonoverlapping(src, out, 1);

self.dst = left.add(!consume_left as usize);
self.end = right.add(consume_left as usize);

if self.dst as *const T == left_end || self.end as *const T == right_end {
break;
}
}
}
}

impl<T> Drop for MergeHole<T> {
impl<T> Drop for MergeState<T> {
fn drop(&mut self) {
// SAFETY: `T` is not a zero-sized type, and these are pointers into a slice's elements.
unsafe {
Expand All @@ -142,31 +149,3 @@ where
}
}
}

/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
/// stores the result into `v[..]`.
///
/// # Safety
///
/// Buffer as pointed to by `buffer` must have space for `buf_len` writes. And must not alias `v`.
#[inline(always)]
pub fn merge<T, F>(v: &mut [T], scratch: &mut [MaybeUninit<T>], mid: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();

if mid == 0 || mid >= len || scratch.len() < cmp::min(mid, len - mid) {
intrinsics::abort();
}

// SAFETY: We checked that the two slices must be non-empty and `mid` must be in bounds. The
// caller has to guarantee that Buffer `buf` must be long enough to hold a copy of the shorter
// slice. Also, `T` must not be a zero-sized type. We checked that T is observation safe. Should
// is_less panic v was not modified in bi_directional_merge and retains it's original input.
// buffer and v must not alias and swap has v.len() space.
unsafe {
let buffer = MaybeUninit::slice_as_mut_ptr(scratch);
merge_fallback(v, mid, buffer, is_less);
}
}

0 comments on commit ba52356

Please sign in to comment.