From c7cf64211a20fd9de9ce34a82e663358195bb42a Mon Sep 17 00:00:00 2001
From: Lukas Bergdoll
Date: Sat, 4 Nov 2023 18:18:03 +0100
Subject: [PATCH 1/3] Improve manual partition loop unrolling

---
 src/quicksort.rs | 39 +++++++++++++++++++++++----------------
 1 file changed, 23 insertions(+), 16 deletions(-)

diff --git a/src/quicksort.rs b/src/quicksort.rs
index ee0cab3..cf69d39 100644
--- a/src/quicksort.rs
+++ b/src/quicksort.rs
@@ -203,7 +203,7 @@ impl<T> StablePartitionTypeImpl for T {
 impl<T> StablePartitionTypeImpl for T
 where
     T: Copy,
-    (): crate::IsTrue<{ mem::size_of::<T>() <= (mem::size_of::<u64>() * 2) }>,
+    (): crate::IsTrue<{ mem::size_of::<T>() <= 16 }>,
 {
     /// See [`StablePartitionTypeImpl::partition_fill_scratch`].
     fn partition_fill_scratch<F: FnMut(&T, &T) -> bool>(
@@ -220,6 +220,13 @@ where
             core::intrinsics::abort()
         }
 
+        // Manually unrolled as a micro-optimization, as only x86 gets auto-unrolling but not Arm.
+        let unroll_len = if const { mem::size_of::<T>() <= 8 } {
+            4
+        } else {
+            2
+        };
+
         unsafe {
             // SAFETY: exactly the same invariants and logic as the
             // non-specialized impl. The conditional store being replaced by a
@@ -229,12 +236,13 @@ where
             // naive loop unrolling where the exact same loop body is just
             // repeated.
             let pivot = v_base.add(pivot_pos);
+            let mut scan = v_base;
             let mut pivot_in_scratch = ptr::null_mut();
             let mut num_lt = 0;
             let mut scratch_rev = scratch_base.add(len);
+
             macro_rules! loop_body {
-                ($i:expr) => {
-                    let scan = v_base.add($i);
+                () => {{
                     scratch_rev = scratch_rev.sub(1);
 
                     let is_less_than_pivot = is_less(&*scan, &*pivot);
@@ -253,24 +261,23 @@ where
                     }
 
                     num_lt += is_less_than_pivot as usize;
-                };
+                    scan = scan.add(1);
+                    _ = scan;
+                }};
             }
 
-            /// To ensure good performance across platforms we explicitly unroll using a
-            /// fixed-size inner loop. We do not simply call the loop body multiple times as
-            /// this increases compile times significantly more, and the compiler unrolls
-            /// a fixed loop just as well, if it is sensible.
-            const UNROLL_LEN: usize = 4;
-            let mut offset = 0;
-            for _ in 0..(len / UNROLL_LEN) {
-                for unroll_i in 0..UNROLL_LEN {
-                    loop_body!(offset + unroll_i);
+            // We do not simply call the loop body multiple times as this increases compile
+            // times significantly more, and the compiler unrolls a fixed loop just as well, if it
+            // is sensible.
+            let unroll_end = v_base.add(len - (unroll_len - 1));
+            while scan < unroll_end {
+                for _ in 0..unroll_len {
+                    loop_body!();
                 }
-                offset += UNROLL_LEN;
             }
 
-            for i in 0..(len % UNROLL_LEN) {
-                loop_body!(offset + i);
+            while scan < v_base.add(len) {
+                loop_body!();
             }
 
             // SAFETY: if T has interior mutability our copy in scratch can be

From 4f3cc11bb9dffb2356e9f8c9329b03e927482063 Mon Sep 17 00:00:00 2001
From: Lukas Bergdoll
Date: Thu, 9 Nov 2023 17:47:48 +0100
Subject: [PATCH 2/3] Use manual unroll approach not prone to pessimisation with opt-level=s

---
 src/quicksort.rs | 71 +++++++++++++++++++++++++++---------------------
 1 file changed, 40 insertions(+), 31 deletions(-)

diff --git a/src/quicksort.rs b/src/quicksort.rs
index cf69d39..9aa5993 100644
--- a/src/quicksort.rs
+++ b/src/quicksort.rs
@@ -193,6 +193,43 @@ impl<T> StablePartitionTypeImpl for T {
     }
 }
 
+/// This construct works around a couple of issues with auto unrolling as well as manual unrolling.
+/// Auto unrolling as tested with rustc 1.75 is somewhat run-time and binary-size inefficient,
+/// because it performs additional math to calculate the loop end, which we can avoid by
+/// precomputing the loop end. Also, auto unrolling only happens on x86 but not on Arm, where doing so
+/// for the Firestorm micro-architecture yields a 15+% performance improvement. Manual unrolling via
+/// repeated code has a large negative impact on debug compile-times, and unrolling via `for _ in
+/// 0..UNROLL_LEN` has a 10-20% perf penalty when compiling with `opt-level=s`, which is deemed
+/// unacceptable for such a crucial component of the sort implementation.
+trait UnrollHelper: Sized {
+    const UNROLL_LEN: usize;
+
+    unsafe fn unrolled_loop_body<F: FnMut()>(loop_body: F);
+}
+
+impl<T> UnrollHelper for T {
+    default const UNROLL_LEN: usize = 2;
+
+    default unsafe fn unrolled_loop_body<F: FnMut()>(mut loop_body: F) {
+        loop_body();
+        loop_body();
+    }
+}
+
+impl<T> UnrollHelper for T
+where
+    (): crate::IsTrue<{ mem::size_of::<T>() <= 8 }>,
+{
+    const UNROLL_LEN: usize = 4;
+
+    unsafe fn unrolled_loop_body<F: FnMut()>(mut loop_body: F) {
+        loop_body();
+        loop_body();
+        loop_body();
+        loop_body();
+    }
+}
+
 /// Specialization for small types, through traits to not invoke compile time
 /// penalties for loop unrolling when not used.
 ///
@@ -202,7 +239,7 @@ impl<T> StablePartitionTypeImpl for T {
 /// store. And explicit loop unrolling is also often very beneficial.
 impl<T> StablePartitionTypeImpl for T
 where
-    T: Copy,
+    T: Copy + crate::Freeze,
     (): crate::IsTrue<{ mem::size_of::<T>() <= 16 }>,
 {
     /// See [`StablePartitionTypeImpl::partition_fill_scratch`].
     fn partition_fill_scratch<F: FnMut(&T, &T) -> bool>(
@@ -220,13 +257,6 @@ where
             core::intrinsics::abort()
         }
 
-        // Manually unrolled as a micro-optimization, as only x86 gets auto-unrolling but not Arm.
-        let unroll_len = if const { mem::size_of::<T>() <= 8 } {
-            4
-        } else {
-            2
-        };
-
         unsafe {
             // SAFETY: exactly the same invariants and logic as the
             // non-specialized impl. The conditional store being replaced by a
@@ -237,7 +267,6 @@ where
             // naive loop unrolling where the exact same loop body is just
             // repeated.
             let pivot = v_base.add(pivot_pos);
             let mut scan = v_base;
-            let mut pivot_in_scratch = ptr::null_mut();
             let mut num_lt = 0;
             let mut scratch_rev = scratch_base.add(len);
@@ -249,43 +278,23 @@ where
                     ptr::copy_nonoverlapping(scan, scratch_base.add(num_lt), 1);
                     ptr::copy_nonoverlapping(scan, scratch_rev.add(num_lt), 1);
 
-                    // Save pivot location in scratch for later.
-                    if const { crate::has_direct_interior_mutability::<T>() }
-                        && intrinsics::unlikely(scan as *const T == pivot)
-                    {
-                        pivot_in_scratch = if is_less_than_pivot {
-                            scratch_base.add(num_lt) // i + num_lt
-                        } else {
-                            scratch_rev.add(num_lt) // len - (i + 1) + num_lt = len - 1 - num_ge
-                        };
-                    }
-
                     num_lt += is_less_than_pivot as usize;
                     scan = scan.add(1);
-                    _ = scan;
                 }};
             }
 
             // We do not simply call the loop body multiple times as this increases compile
             // times significantly more, and the compiler unrolls a fixed loop just as well, if it
             // is sensible.
-            let unroll_end = v_base.add(len - (unroll_len - 1));
+            let unroll_end = v_base.add(len - (T::UNROLL_LEN - 1));
             while scan < unroll_end {
-                for _ in 0..unroll_len {
-                    loop_body!();
-                }
+                T::unrolled_loop_body(|| loop_body!());
             }
 
             while scan < v_base.add(len) {
                 loop_body!();
             }
 
-            // SAFETY: if T has interior mutability our copy in scratch can be
-            // outdated, update it.
-            if const { crate::has_direct_interior_mutability::<T>() } {
-                ptr::copy_nonoverlapping(pivot, pivot_in_scratch, 1);
-            }
-
             num_lt
         }
     }

From 54b89fe82467e611a8ce56e604984249e7f0431f Mon Sep 17 00:00:00 2001
From: Lukas Bergdoll
Date: Fri, 10 Nov 2023 17:18:52 +0100
Subject: [PATCH 3/3] Apply review comments

---
 src/quicksort.rs | 76 +++++++++++++++++++++++++++++++-----------------
 1 file changed, 49 insertions(+), 27 deletions(-)

diff --git a/src/quicksort.rs b/src/quicksort.rs
index 9aa5993..2ce8d4e 100644
--- a/src/quicksort.rs
+++ b/src/quicksort.rs
@@ -193,6 +193,16 @@ impl<T> StablePartitionTypeImpl for T {
     }
 }
 
+struct PartitionState<T> {
+    // The current element that is being looked at, scans left to right through the slice.
+    scan: *const T,
+    // Counts the number of elements that compared less-than; also works around:
+    // https://github.com/rust-lang/rust/issues/117128
+    num_lt: usize,
+    // Reverse scratch output pointer.
+    scratch_rev: *mut T,
+}
+
 /// This construct works around a couple of issues with auto unrolling as well as manual unrolling.
 /// Auto unrolling as tested with rustc 1.75 is somewhat run-time and binary-size inefficient,
 /// because it performs additional math to calculate the loop end, which we can avoid by
@@ -204,15 +214,22 @@ impl<T> StablePartitionTypeImpl for T {
 trait UnrollHelper: Sized {
     const UNROLL_LEN: usize;
 
-    unsafe fn unrolled_loop_body<F: FnMut()>(loop_body: F);
+    unsafe fn unrolled_loop_body<F: FnMut(&mut PartitionState<Self>)>(
+        loop_body: F,
+        state: &mut PartitionState<Self>,
+    );
 }
 
 impl<T> UnrollHelper for T {
     default const UNROLL_LEN: usize = 2;
 
-    default unsafe fn unrolled_loop_body<F: FnMut()>(mut loop_body: F) {
-        loop_body();
-        loop_body();
+    #[inline(always)]
+    default unsafe fn unrolled_loop_body<F: FnMut(&mut PartitionState<Self>)>(
+        mut loop_body: F,
+        state: &mut PartitionState<Self>,
+    ) {
+        loop_body(state);
+        loop_body(state);
     }
 }
 
@@ -222,11 +239,15 @@ where
 {
     const UNROLL_LEN: usize = 4;
 
-    unsafe fn unrolled_loop_body<F: FnMut()>(mut loop_body: F) {
-        loop_body();
-        loop_body();
-        loop_body();
-        loop_body();
+    #[inline(always)]
+    unsafe fn unrolled_loop_body<F: FnMut(&mut PartitionState<Self>)>(
+        mut loop_body: F,
+        state: &mut PartitionState<Self>,
+    ) {
+        loop_body(state);
+        loop_body(state);
+        loop_body(state);
+        loop_body(state);
     }
 }
 
@@ -266,36 +287,37 @@ where
             // naive loop unrolling where the exact same loop body is just
             // repeated.
             let pivot = v_base.add(pivot_pos);
-            let mut scan = v_base;
-            let mut num_lt = 0;
-            let mut scratch_rev = scratch_base.add(len);
 
-            macro_rules! loop_body {
-                () => {{
-                    scratch_rev = scratch_rev.sub(1);
+            let mut loop_body = |state: &mut PartitionState<T>| {
+                state.scratch_rev = state.scratch_rev.sub(1);
 
-                    let is_less_than_pivot = is_less(&*scan, &*pivot);
-                    ptr::copy_nonoverlapping(scan, scratch_base.add(num_lt), 1);
-                    ptr::copy_nonoverlapping(scan, scratch_rev.add(num_lt), 1);
+                let is_less_than_pivot = is_less(&*state.scan, &*pivot);
+                ptr::copy_nonoverlapping(state.scan, scratch_base.add(state.num_lt), 1);
+                ptr::copy_nonoverlapping(state.scan, state.scratch_rev.add(state.num_lt), 1);
 
-                    num_lt += is_less_than_pivot as usize;
-                    scan = scan.add(1);
-                }};
-            }
+                state.num_lt += is_less_than_pivot as usize;
+                state.scan = state.scan.add(1);
+            };
+
+            let mut state = PartitionState {
+                scan: v_base,
+                num_lt: 0,
+                scratch_rev: scratch_base.add(len),
+            };
 
             // We do not simply call the loop body multiple times as this increases compile
             // times significantly more, and the compiler unrolls a fixed loop just as well, if it
             // is sensible.
             let unroll_end = v_base.add(len - (T::UNROLL_LEN - 1));
-            while scan < unroll_end {
-                T::unrolled_loop_body(|| loop_body!());
+            while state.scan < unroll_end {
+                T::unrolled_loop_body(&mut loop_body, &mut state);
             }
 
-            while scan < v_base.add(len) {
-                loop_body!();
+            while state.scan < v_base.add(len) {
+                loop_body(&mut state);
             }
 
-            num_lt
+            state.num_lt
         }
     }
 }
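
The shape the three patches converge on can also be sketched outside the diff. The following stand-alone Rust sketch is illustrative only and is not the crate's code: it keeps the unroll factor as a plain constant instead of the size-based `UnrollHelper` specialization, merely counts elements below a pivot instead of writing into the scratch buffer, and adds a `saturating_sub` guard for tiny inputs. The names `ScanState`, `count_less_than`, and the fixed `UNROLL_LEN = 4` are made up for the example.

struct ScanState {
    scan: usize,   // index of the element currently being looked at
    num_lt: usize, // running count of elements that compared less than the pivot
}

const UNROLL_LEN: usize = 4;

// Calls `loop_body` UNROLL_LEN times back to back, mirroring the shape of
// `UnrollHelper::unrolled_loop_body` from patches 2 and 3: no
// `for _ in 0..UNROLL_LEN` counter, which the commit message calls prone to
// pessimisation under `opt-level=s`.
#[inline(always)]
fn unrolled_loop_body<F: FnMut(&mut ScanState)>(mut loop_body: F, state: &mut ScanState) {
    loop_body(state);
    loop_body(state);
    loop_body(state);
    loop_body(state);
}

// Counts how many elements of `v` are strictly less than `pivot`, handling
// UNROLL_LEN elements per iteration of the main loop and finishing the
// remainder with a scalar tail loop, like the partition loop in the patches.
fn count_less_than(v: &[i32], pivot: i32) -> usize {
    let len = v.len();
    let mut state = ScanState { scan: 0, num_lt: 0 };

    // The loop body is written exactly once, as a closure over the state.
    let mut loop_body = |state: &mut ScanState| {
        state.num_lt += (v[state.scan] < pivot) as usize;
        state.scan += 1;
    };

    // Main loop: only entered while at least UNROLL_LEN elements remain.
    // saturating_sub guards the small-len case that the patch handles by
    // construction of its callers.
    let unroll_end = len.saturating_sub(UNROLL_LEN - 1);
    while state.scan < unroll_end {
        unrolled_loop_body(&mut loop_body, &mut state);
    }

    // Scalar tail for the 0..UNROLL_LEN-1 leftover elements.
    while state.scan < len {
        loop_body(&mut state);
    }

    state.num_lt
}

fn main() {
    let v = [5, 1, 9, 3, 7, 2, 8];
    assert_eq!(count_less_than(&v, 5), 3);
    println!("ok: 3 elements below the pivot");
}

The point of this shape is that the unrolled main loop and the scalar tail share the single loop body, the loop end is precomputed so the hot path contains no division or modulo by the unroll factor, and the body is spelled out only once, which is what keeps debug compile times in check per the doc comment added in patch 2.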
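
The size threshold driving the unroll factor can likewise be exercised in isolation. A minimal sketch, assuming inline `const` blocks (stable since Rust 1.79); `unroll_len` is a hypothetical free function, not part of the patches, but it mirrors the `<= 8` cutoff that patch 1 computes inline and that patch 2 moves into the `UnrollHelper` specialization (a 4x unroll for at most 8-byte elements, 2x otherwise, with the whole specialization gated on types of at most 16 bytes).

use std::mem;

// Hypothetical helper mirroring the unroll-factor choice in the patches:
// elements of at most 8 bytes get a 4x unroll, larger ones a 2x unroll.
fn unroll_len<T>() -> usize {
    if const { mem::size_of::<T>() <= 8 } { 4 } else { 2 }
}

fn main() {
    assert_eq!(unroll_len::<u64>(), 4); // 8 bytes  -> deeper unroll
    assert_eq!(unroll_len::<u128>(), 2); // 16 bytes -> shallower unroll
    println!("ok");
}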