Skip to content

Commit

Permalink
Add k_smallest_relaxed and variants
Browse files Browse the repository at this point in the history
This implements the algorithm described in [1] which consumes twice the amount
of memory as the existing `k_smallest` algorithm but achieves linear time in the
number of elements in the input.

[1] https://quickwit.io/blog/top-k-complexity
  • Loading branch information
adamreichold authored and Philippe-Cholet committed Jul 3, 2024
1 parent 1c850ce commit 31c7fd8
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 2 deletions.
39 changes: 39 additions & 0 deletions src/k_smallest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,45 @@ where
storage
}

pub(crate) fn k_smallest_relaxed_general<I, F>(iter: I, k: usize, mut comparator: F) -> Vec<I::Item>
where
I: Iterator,
F: FnMut(&I::Item, &I::Item) -> Ordering,
{
if k == 0 {
iter.last();
return Vec::new();
}

let mut iter = iter.fuse();
let mut buf = iter.by_ref().take(2 * k).collect::<Vec<_>>();

if buf.len() < k {
buf.sort_unstable_by(&mut comparator);
return buf;
}

buf.select_nth_unstable_by(k - 1, &mut comparator);
buf.truncate(k);

iter.for_each(|val| {
if comparator(&val, &buf[k - 1]) != Ordering::Less {
return;
}

buf.push(val);

if buf.len() == 2 * k {
buf.select_nth_unstable_by(k - 1, &mut comparator);
buf.truncate(k);
}
});

buf.sort_unstable_by(&mut comparator);
buf.truncate(k);
buf
}

#[inline]
pub(crate) fn key_to_cmp<T, K, F>(mut key: F) -> impl FnMut(&T, &T) -> Ordering
where
Expand Down
187 changes: 187 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3153,6 +3153,105 @@ pub trait Itertools: Iterator {
self.k_smallest_by(k, k_smallest::key_to_cmp(key))
}

/// Sort the k smallest elements into a new iterator, in ascending order, relaxing the amount of memory required.
///
/// **Note:** This consumes the entire iterator, and returns the result
/// as a new iterator that owns its elements. If the input contains
/// less than k elements, the result is equivalent to `self.sorted()`.
///
/// This is guaranteed to use `2 * k * sizeof(Self::Item) + O(1)` memory
/// and `O(n + k log k)` time, with `n` the number of elements in the input,
/// meaning it uses more memory than the minimum obtained by [`k_smallest`](Itertools::k_smallest)
/// but achieves linear time in the number of elements.
///
/// The sorted iterator, if directly collected to a `Vec`, is converted
/// without any extra copying or allocation cost.
///
/// **Note:** This is functionally-equivalent to `self.sorted().take(k)`
/// but much more efficient.
///
/// ```
/// use itertools::Itertools;
///
/// // A random permutation of 0..15
/// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5];
///
/// let five_smallest = numbers
/// .into_iter()
/// .k_smallest_relaxed(5);
///
/// itertools::assert_equal(five_smallest, 0..5);
/// ```
#[cfg(feature = "use_alloc")]
fn k_smallest_relaxed(self, k: usize) -> VecIntoIter<Self::Item>
where
Self: Sized,
Self::Item: Ord,
{
self.k_smallest_relaxed_by(k, Ord::cmp)
}

/// Sort the k smallest elements into a new iterator using the provided comparison, relaxing the amount of memory required.
///
/// The sorted iterator, if directly collected to a `Vec`, is converted
/// without any extra copying or allocation cost.
///
/// This corresponds to `self.sorted_by(cmp).take(k)` in the same way that
/// [`k_smallest_relaxed`](Itertools::k_smallest_relaxed) corresponds to `self.sorted().take(k)`,
/// in both semantics and complexity.
///
/// ```
/// use itertools::Itertools;
///
/// // A random permutation of 0..15
/// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5];
///
/// let five_smallest = numbers
/// .into_iter()
/// .k_smallest_relaxed_by(5, |a, b| (a % 7).cmp(&(b % 7)).then(a.cmp(b)));
///
/// itertools::assert_equal(five_smallest, vec![0, 7, 14, 1, 8]);
/// ```
#[cfg(feature = "use_alloc")]
fn k_smallest_relaxed_by<F>(self, k: usize, cmp: F) -> VecIntoIter<Self::Item>
where
Self: Sized,
F: FnMut(&Self::Item, &Self::Item) -> Ordering,
{
k_smallest::k_smallest_relaxed_general(self, k, cmp).into_iter()
}

/// Return the elements producing the k smallest outputs of the provided function, relaxing the amount of memory required.
///
/// The sorted iterator, if directly collected to a `Vec`, is converted
/// without any extra copying or allocation cost.
///
/// This corresponds to `self.sorted_by_key(key).take(k)` in the same way that
/// [`k_smallest_relaxed`](Itertools::k_smallest_relaxed) corresponds to `self.sorted().take(k)`,
/// in both semantics and complexity.
///
/// ```
/// use itertools::Itertools;
///
/// // A random permutation of 0..15
/// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5];
///
/// let five_smallest = numbers
/// .into_iter()
/// .k_smallest_relaxed_by_key(5, |n| (n % 7, *n));
///
/// itertools::assert_equal(five_smallest, vec![0, 7, 14, 1, 8]);
/// ```
#[cfg(feature = "use_alloc")]
fn k_smallest_relaxed_by_key<F, K>(self, k: usize, key: F) -> VecIntoIter<Self::Item>
where
Self: Sized,
F: FnMut(&Self::Item) -> K,
K: Ord,
{
self.k_smallest_relaxed_by(k, k_smallest::key_to_cmp(key))
}

/// Sort the k largest elements into a new iterator, in descending order.
///
/// The sorted iterator, if directly collected to a `Vec`, is converted
Expand Down Expand Up @@ -3243,6 +3342,94 @@ pub trait Itertools: Iterator {
self.k_largest_by(k, k_smallest::key_to_cmp(key))
}

/// Sort the k largest elements into a new iterator, in descending order, relaxing the amount of memory required.
///
/// The sorted iterator, if directly collected to a `Vec`, is converted
/// without any extra copying or allocation cost.
///
/// It is semantically equivalent to [`k_smallest_relaxed`](Itertools::k_smallest_relaxed)
/// with a reversed `Ord`.
///
/// ```
/// use itertools::Itertools;
///
/// // A random permutation of 0..15
/// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5];
///
/// let five_largest = numbers
/// .into_iter()
/// .k_largest_relaxed(5);
///
/// itertools::assert_equal(five_largest, vec![14, 13, 12, 11, 10]);
/// ```
#[cfg(feature = "use_alloc")]
fn k_largest_relaxed(self, k: usize) -> VecIntoIter<Self::Item>
where
Self: Sized,
Self::Item: Ord,
{
self.k_largest_relaxed_by(k, Self::Item::cmp)
}

/// Sort the k largest elements into a new iterator using the provided comparison, relaxing the amount of memory required.
///
/// The sorted iterator, if directly collected to a `Vec`, is converted
/// without any extra copying or allocation cost.
///
/// Functionally equivalent to [`k_smallest_relaxed_by`](Itertools::k_smallest_relaxed_by)
/// with a reversed `Ord`.
///
/// ```
/// use itertools::Itertools;
///
/// // A random permutation of 0..15
/// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5];
///
/// let five_largest = numbers
/// .into_iter()
/// .k_largest_relaxed_by(5, |a, b| (a % 7).cmp(&(b % 7)).then(a.cmp(b)));
///
/// itertools::assert_equal(five_largest, vec![13, 6, 12, 5, 11]);
/// ```
#[cfg(feature = "use_alloc")]
fn k_largest_relaxed_by<F>(self, k: usize, mut cmp: F) -> VecIntoIter<Self::Item>
where
Self: Sized,
F: FnMut(&Self::Item, &Self::Item) -> Ordering,
{
self.k_smallest_relaxed_by(k, move |a, b| cmp(b, a))
}

/// Return the elements producing the k largest outputs of the provided function, relaxing the amount of memory required.
///
/// The sorted iterator, if directly collected to a `Vec`, is converted
/// without any extra copying or allocation cost.
///
/// Functionally equivalent to [`k_smallest_relaxed_by_key`](Itertools::k_smallest_relaxed_by_key)
/// with a reversed `Ord`.
///
/// ```
/// use itertools::Itertools;
///
/// // A random permutation of 0..15
/// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5];
///
/// let five_largest = numbers
/// .into_iter()
/// .k_largest_relaxed_by_key(5, |n| (n % 7, *n));
///
/// itertools::assert_equal(five_largest, vec![13, 6, 12, 5, 11]);
/// ```
#[cfg(feature = "use_alloc")]
fn k_largest_relaxed_by_key<F, K>(self, k: usize, key: F) -> VecIntoIter<Self::Item>
where
Self: Sized,
F: FnMut(&Self::Item) -> K,
K: Ord,
{
self.k_largest_relaxed_by(k, k_smallest::key_to_cmp(key))
}

/// Consumes the iterator and return an iterator of the last `n` elements.
///
/// The iterator, if directly collected to a `VecDeque`, is converted
Expand Down
46 changes: 44 additions & 2 deletions tests/test_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,42 @@ qc::quickcheck! {
it::assert_equal(largest_by, sorted_largest.clone());
it::assert_equal(largest_by_key, sorted_largest);
}

fn k_smallest_relaxed_range(n: i64, m: u16, k: u16) -> () {
// u16 is used to constrain k and m to 0..2¹⁶,
// otherwise the test could use too much memory.
let (k, m) = (k as usize, m as u64);

let mut v: Vec<_> = (n..n.saturating_add(m as _)).collect();
// Generate a random permutation of n..n+m
v.shuffle(&mut thread_rng());

// Construct the right answers for the top and bottom elements
let mut sorted = v.clone();
sorted.sort();
// how many elements are we checking
let num_elements = min(k, m as _);

// Compute the top and bottom k in various combinations
let sorted_smallest = sorted[..num_elements].iter().cloned();
let smallest = v.iter().cloned().k_smallest_relaxed(k);
let smallest_by = v.iter().cloned().k_smallest_relaxed_by(k, Ord::cmp);
let smallest_by_key = v.iter().cloned().k_smallest_relaxed_by_key(k, |&x| x);

let sorted_largest = sorted[sorted.len() - num_elements..].iter().rev().cloned();
let largest = v.iter().cloned().k_largest_relaxed(k);
let largest_by = v.iter().cloned().k_largest_relaxed_by(k, Ord::cmp);
let largest_by_key = v.iter().cloned().k_largest_relaxed_by_key(k, |&x| x);

// Check the variations produce the same answers and that they're right
it::assert_equal(smallest, sorted_smallest.clone());
it::assert_equal(smallest_by, sorted_smallest.clone());
it::assert_equal(smallest_by_key, sorted_smallest);

it::assert_equal(largest, sorted_largest.clone());
it::assert_equal(largest_by, sorted_largest.clone());
it::assert_equal(largest_by_key, sorted_largest);
}
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -572,8 +608,11 @@ where
I::Item: Ord + Debug,
{
let j = i.clone();
let i1 = i.clone();
let j1 = i.clone();
let k = k as usize;
it::assert_equal(i.k_smallest(k), j.sorted().take(k))
it::assert_equal(i.k_smallest(k), j.sorted().take(k));
it::assert_equal(i1.k_smallest_relaxed(k), j1.sorted().take(k));
}

// Similar to `k_smallest_sort` but for our custom heap implementation.
Expand All @@ -583,8 +622,11 @@ where
I::Item: Ord + Debug,
{
let j = i.clone();
let i1 = i.clone();
let j1 = i.clone();
let k = k as usize;
it::assert_equal(i.k_smallest_by(k, Ord::cmp), j.sorted().take(k))
it::assert_equal(i.k_smallest_by(k, Ord::cmp), j.sorted().take(k));
it::assert_equal(i1.k_smallest_relaxed_by(k, Ord::cmp), j1.sorted().take(k));
}

macro_rules! generic_test {
Expand Down

0 comments on commit 31c7fd8

Please sign in to comment.