diff --git a/src/modint.rs b/src/modint.rs index 5779f27..e172fc6 100644 --- a/src/modint.rs +++ b/src/modint.rs @@ -673,7 +673,7 @@ pub trait ModIntBase: #[inline] fn pow(self, mut n: u64) -> Self { let mut x = self; - let mut r = Self::raw(1); + let mut r = Self::raw((Self::modulus() > 1) as u32); while n > 0 { if n & 1 == 1 { r *= x; @@ -1043,9 +1043,9 @@ macro_rules! impl_folding { impl_folding! { impl Sum<_> for StaticModInt { fn sum(_) -> _ { _(Self::raw(0), Add::add) } } - impl Product<_> for StaticModInt { fn product(_) -> _ { _(Self::raw(1), Mul::mul) } } + impl Product<_> for StaticModInt { fn product(_) -> _ { _(Self::raw((Self::modulus() > 1) as u32), Mul::mul) } } impl Sum<_> for DynamicModInt { fn sum(_) -> _ { _(Self::raw(0), Add::add) } } - impl Product<_> for DynamicModInt { fn product(_) -> _ { _(Self::raw(1), Mul::mul) } } + impl Product<_> for DynamicModInt { fn product(_) -> _ { _(Self::raw((Self::modulus() > 1) as u32), Mul::mul) } } } #[cfg(test)] @@ -1160,7 +1160,21 @@ mod tests { assert_eq!(expected, c); } - // test `2^31 < modulus < 2^32` case. + // Corner cases of "modint" when mod = 1 + // https://github.com/rust-lang-ja/ac-library-rs/issues/110 + #[test] + fn mod1_corner_case() { + ModInt::set_modulus(1); // !! + + let x: ModInt = std::iter::empty::().product(); + assert_eq!(x.val(), 0); + + let y = ModInt::new(123).pow(0); + assert_eq!(y.val(), 0); + } + + // test `2^31 < modulus < 2^32` case + // https://github.com/rust-lang-ja/ac-library-rs/issues/111 #[test] fn dynamic_modint_m32() { let m = 3221225471;