diff --git a/src/combinations.rs b/src/combinations.rs index 68a59c5e4..59654a1c0 100644 --- a/src/combinations.rs +++ b/src/combinations.rs @@ -2,6 +2,7 @@ use std::fmt; use std::iter::FusedIterator; use super::lazy_buffer::LazyBuffer; +use super::size_hint::{self, SizeHint}; use alloc::vec::Vec; /// An iterator to iterate through all the `k`-length combinations in an iterator. @@ -120,9 +121,34 @@ impl Iterator for Combinations // Create result vector based on the indices Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect()) } + + fn size_hint(&self) -> SizeHint { + let k = self.k(); + size_hint::try_map(self.pool.size_hint(), |n| { + if self.first { + binomial(n, k) + } else { + self.indices + .iter() + .enumerate() + .fold(Some(0), |sum, (k0, n0)| { + sum.and_then(|s| s.checked_add(binomial(n - 1 - *n0, k - k0)?)) + }) + } + }) + } } impl FusedIterator for Combinations where I: Iterator, I::Item: Clone {} + +pub(crate) fn binomial(n: usize, k: usize) -> Option { + if n < k { + return Some(0); + } + // n! / (n - k)! / k! but trying to avoid it overflows: + let k = (n - k).min(k); + (1..=k).fold(Some(1), |res, i| res.and_then(|x| x.checked_mul(n - i + 1).map(|x| x / i))) +} diff --git a/src/combinations_with_replacement.rs b/src/combinations_with_replacement.rs index 9b62f8ce8..610960ffd 100644 --- a/src/combinations_with_replacement.rs +++ b/src/combinations_with_replacement.rs @@ -2,6 +2,7 @@ use alloc::vec::Vec; use std::fmt; use std::iter::FusedIterator; +use super::combinations::binomial; use super::lazy_buffer::LazyBuffer; use super::size_hint::{self, SizeHint}; @@ -103,14 +104,6 @@ where } fn size_hint(&self) -> SizeHint { - fn binomial(n: usize, k: usize) -> Option { - if n < k { - return Some(0); - } - // n! / (n - k)! / k! but trying to avoid it overflows: - let k = (n - k).min(k); - (1..=k).fold(Some(1), |res, i| res.and_then(|x| x.checked_mul(n - i + 1).map(|x| x / i))) - } let k_perms = |n: usize, k: usize| binomial((n + k).saturating_sub(1), k); let k = self.indices.len(); size_hint::try_map(self.pool.size_hint(), |n| { diff --git a/src/size_hint.rs b/src/size_hint.rs index 008a0a289..920964cac 100644 --- a/src/size_hint.rs +++ b/src/size_hint.rs @@ -119,6 +119,8 @@ pub fn min(a: SizeHint, b: SizeHint) -> SizeHint { } /// Try to apply a function `f` on both bounds of a `SizeHint`, failure means overflow. +/// +/// For the resulting size hint to be correct, `f` must be increasing. #[inline] pub fn try_map(sh: SizeHint, mut f: F) -> SizeHint where diff --git a/tests/test_std.rs b/tests/test_std.rs index 500175f9d..f76ee45b6 100644 --- a/tests/test_std.rs +++ b/tests/test_std.rs @@ -909,6 +909,20 @@ fn combinations_zero() { it::assert_equal((0..0).combinations(0), vec![vec![]]); } +#[test] +fn combinations_range_size_hint() { + for n in 0..6 { + for k in 0..=n { + let len = (n - k + 1..=n).product::() / (1..=k).product::(); + let mut it = (0..n).combinations(k); + for count in (0..=len).rev() { + assert_eq!(it.size_hint(), (count, Some(count))); + assert_eq!(it.next().is_none(), count == 0); + } + } + } +} + #[test] fn permutations_zero() { it::assert_equal((1..3).permutations(0), vec![vec![]]);