Skip to content

Commit

Permalink
Apply review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Voultapher committed Aug 16, 2023
1 parent ba52356 commit e79ed09
Showing 1 changed file with 21 additions and 37 deletions.
58 changes: 21 additions & 37 deletions src/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use core::ptr;

/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
/// stores the result into `v[..]`.
#[inline(never)]
pub fn merge<T, F>(v: &mut [T], scratch: &mut [MaybeUninit<T>], mid: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
Expand All @@ -16,11 +15,10 @@ where
intrinsics::abort();
}

// SAFETY: We checked that the two slices must be non-empty and `mid` must be in bounds. The
// caller has to guarantee that Buffer `buf` must be long enough to hold a copy of the shorter
// slice. Also, `T` must not be a zero-sized type. We checked that T is observation safe. Should
// is_less panic v was not modified in bi_directional_merge and retains it's original input.
// buffer and v must not alias and swap has v.len() space.
// SAFETY: We checked that the two slices are non-empty and `mid` is in bounds. We checked that
// the Buffer `scratch` has enough capacity to hold a copy of the shorter slice. `merge_up` and
// `merge_down` are written in such a way that they uphold the contract described in
// `MergeState::drop`.
unsafe {
// The merge process first copies the shorter run into `buf`. Then it traces the newly
// copied run and the longer run forwards (or backwards), comparing their next unconsumed
Expand All @@ -33,12 +31,6 @@ where
// Intermediate state of the process is always tracked by `gap`, which serves two purposes:
// 1. Protects integrity of `v` from panics in `is_less`.
// 2. Fills the remaining gap in `v` if the longer run gets consumed first.
//
// Panic safety:
//
// If `is_less` panics at any point during the process, `gap` will get dropped and fill the
// gap in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every
// object it initially held exactly once.

let buf = MaybeUninit::slice_as_mut_ptr(scratch);

Expand All @@ -56,23 +48,15 @@ where

ptr::copy_nonoverlapping(save_base, buf, save_len);

let mut merge_state;
let mut merge_state = MergeState {
start: buf,
end: buf.add(save_len),
dst: save_base,
};

if left_is_shorter {
merge_state = MergeState {
start: buf,
end: buf.add(mid),
dst: v_base,
};

merge_state.merge_up(v_mid, v_end, is_less);
} else {
merge_state = MergeState {
start: buf,
end: buf.add(len - mid),
dst: v_mid,
};

merge_state.merge_down(v_base, buf, v_end, is_less);
}
// Finally, `merge_state` gets dropped. If the shorter run was not fully consumed, whatever
Expand All @@ -93,19 +77,19 @@ where
right_end: *const T,
is_less: &mut F,
) {
// left == self.start
// out == self.dst
let left = &mut self.start;
let out = &mut self.dst;

while self.start != self.end && right as *const T != right_end {
let consume_left = !is_less(&*right, &*self.start);
while *left != self.end && right as *const T != right_end {
let consume_left = !is_less(&*right, &**left);

let src = if consume_left { self.start } else { right };
ptr::copy_nonoverlapping(src, self.dst, 1);
let src = if consume_left { *left } else { right };
ptr::copy_nonoverlapping(src, *out, 1);

self.start = self.start.add(consume_left as usize);
*left = left.add(consume_left as usize);
right = right.add(!consume_left as usize);

self.dst = self.dst.add(1);
*out = out.add(1);
}
}

Expand All @@ -116,9 +100,6 @@ where
mut out: *mut T,
is_less: &mut F,
) {
// left == self.dst;
// right == self.end;

loop {
let left = self.dst.sub(1);
let right = self.end.sub(1);
Expand All @@ -141,7 +122,10 @@ where

impl<T> Drop for MergeState<T> {
fn drop(&mut self) {
// SAFETY: `T` is not a zero-sized type, and these are pointers into a slice's elements.
// SAFETY: The user of MergeState MUST ensure, that at any point this drop impl MAY run,
// for example when the user provided `is_less` panics, that copying the contiguous
// region between `start` and `end` to `dst` will leave the input slice `v` with each
// original element and all possible modifications observed.
unsafe {
let len = self.end.sub_ptr(self.start);
ptr::copy_nonoverlapping(self.start, self.dst, len);
Expand Down

0 comments on commit e79ed09

Please sign in to comment.