Skip to content

Commit

Permalink
size_hint for combinations
Browse files Browse the repository at this point in the history
  • Loading branch information
Philippe-Cholet committed Aug 13, 2023
1 parent 2524a9c commit 06d5095
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 8 deletions.
26 changes: 26 additions & 0 deletions src/combinations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fmt;
use std::iter::FusedIterator;

use super::lazy_buffer::LazyBuffer;
use super::size_hint::{self, SizeHint};
use alloc::vec::Vec;

/// An iterator to iterate through all the `k`-length combinations in an iterator.
Expand Down Expand Up @@ -120,9 +121,34 @@ impl<I> Iterator for Combinations<I>
// Create result vector based on the indices
Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect())
}

fn size_hint(&self) -> SizeHint {
let k = self.k();
size_hint::try_map(self.pool.size_hint(), |n| {
if self.first {
binomial(n, k)
} else {
self.indices
.iter()
.enumerate()
.fold(Some(0), |sum, (k0, n0)| {
sum.and_then(|s| s.checked_add(binomial(n - 1 - *n0, k - k0)?))
})
}
})
}
}

impl<I> FusedIterator for Combinations<I>
where I: Iterator,
I::Item: Clone
{}

pub(crate) fn binomial(n: usize, k: usize) -> Option<usize> {
if n < k {
return Some(0);
}
// n! / (n - k)! / k! but trying to avoid it overflows:
let k = (n - k).min(k);
(1..=k).fold(Some(1), |res, i| res.and_then(|x| x.checked_mul(n - i + 1).map(|x| x / i)))
}
9 changes: 1 addition & 8 deletions src/combinations_with_replacement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use alloc::vec::Vec;
use std::fmt;
use std::iter::FusedIterator;

use super::combinations::binomial;
use super::lazy_buffer::LazyBuffer;
use super::size_hint::{self, SizeHint};

Expand Down Expand Up @@ -103,14 +104,6 @@ where
}

fn size_hint(&self) -> SizeHint {
fn binomial(n: usize, k: usize) -> Option<usize> {
if n < k {
return Some(0);
}
// n! / (n - k)! / k! but trying to avoid it overflows:
let k = (n - k).min(k);
(1..=k).fold(Some(1), |res, i| res.and_then(|x| x.checked_mul(n - i + 1).map(|x| x / i)))
}
let k_perms = |n: usize, k: usize| binomial((n + k).saturating_sub(1), k);
let k = self.indices.len();
size_hint::try_map(self.pool.size_hint(), |n| {
Expand Down
2 changes: 2 additions & 0 deletions src/size_hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ pub fn min(a: SizeHint, b: SizeHint) -> SizeHint {
}

/// Try to apply a function `f` on both bounds of a `SizeHint`, failure means overflow.
///
/// For the resulting size hint to be correct, `f` must be increasing.
#[inline]
pub fn try_map<F>(sh: SizeHint, mut f: F) -> SizeHint
where
Expand Down
14 changes: 14 additions & 0 deletions tests/test_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,20 @@ fn combinations_zero() {
it::assert_equal((0..0).combinations(0), vec![vec![]]);
}

#[test]
fn combinations_range_size_hint() {
for n in 0..6 {
for k in 0..=n {
let len = (n - k + 1..=n).product::<usize>() / (1..=k).product::<usize>();
let mut it = (0..n).combinations(k);
for count in (0..=len).rev() {
assert_eq!(it.size_hint(), (count, Some(count)));
assert_eq!(it.next().is_none(), count == 0);
}
}
}
}

#[test]
fn permutations_zero() {
it::assert_equal((1..3).permutations(0), vec![vec![]]);
Expand Down

0 comments on commit 06d5095

Please sign in to comment.