Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make interface of NCRModuloP fail-safe #2469

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 128 additions & 83 deletions math/ncr_modulo_p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
*/

#include <cassert> /// for assert
#include <iostream> /// for io operations
#include <iostream> /// for std::cout
#include <vector> /// for std::vector

/**
Expand All @@ -25,71 +25,95 @@ namespace math {
* implementation.
*/
namespace ncr_modulo_p {

/**
* @brief Class which contains all methods required for calculating nCr mod p
* @namespace utils
* @brief this namespace contains the definitions of the functions called from
* the class math::ncr_modulo_p::NCRModuloP
*/
class NCRModuloP {
private:
std::vector<uint64_t> fac{}; /// stores precomputed factorial(i) % p value
uint64_t p = 0; /// the p from (nCr % p)

public:
/** Constructor which precomputes the values of n! % mod from n=0 to size
* and stores them in vector 'fac'
* @params[in] the numbers 'size', 'mod'
*/
NCRModuloP(const uint64_t& size, const uint64_t& mod) {
p = mod;
fac = std::vector<uint64_t>(size);
fac[0] = 1;
for (int i = 1; i <= size; i++) {
fac[i] = (fac[i - 1] * i) % p;
}
namespace utils {
vil02 marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief finds the values x and y such that a*x + b*y = gcd(a,b)
*
* @param[in] a the first input of the gcd
* @param[in] a the second input of the gcd
* @param[out] x the Bézout coefficient of a
* @param[out] y the Bézout coefficient of b
* @return the gcd of a and b
*/
int64_t gcdExtended(const int64_t& a, const int64_t& b, int64_t& x,
vil02 marked this conversation as resolved.
Show resolved Hide resolved
int64_t& y) {
if (a == 0) {
x = 0;
y = 1;
return b;
}

/** Finds the value of x, y such that a*x + b*y = gcd(a,b)
*
* @params[in] the numbers 'a', 'b' and address of 'x' and 'y' from above
* equation
* @returns the gcd of a and b
*/
uint64_t gcdExtended(const uint64_t& a, const uint64_t& b, int64_t* x,
int64_t* y) {
if (a == 0) {
*x = 0, *y = 1;
return b;
}
int64_t x1 = 0, y1 = 0;
const int64_t gcd = gcdExtended(b % a, a, x1, y1);

int64_t x1 = 0, y1 = 0;
uint64_t gcd = gcdExtended(b % a, a, &x1, &y1);
x = y1 - (b / a) * x1;
y = x1;
return gcd;
}

*x = y1 - (b / a) * x1;
*y = x1;
return gcd;
/** Find modular inverse of a modulo m i.e. a number x such that (a*x)%m = 1
*
* @param[in] a the number for which the modular inverse is queried
* @param[in] m the modulus
* @return the inverce of a modulo m, if it exists, -1 otherwise
*/
int64_t modInverse(const int64_t& a, const int64_t& m) {
int64_t x = 0, y = 0;
const int64_t g = gcdExtended(a, m, x, y);
if (g != 1) { // modular inverse doesn't exist
return -1;
} else {
return ((x + m) % m);
}
}
} // namespace utils
/**
* @brief Class which contains all methods required for calculating nCr mod p
*/
class NCRModuloP {
private:
const int64_t p = 0; /// the p from (nCr % p)
const std::vector<int64_t>
fac; /// stores precomputed factorial(i) % p value

/** Find modular inverse of a with m i.e. a number x such that (a*x)%m = 1
*
* @params[in] the numbers 'a' and 'm' from above equation
* @returns the modular inverse of a
/**
* @brief computes the array of values of factorials reduced modulo mod
* @param max_arg_val argument of the last factorial stored in the result
* @param mod value of the divisor used to reduce factorials
* @return vector storing factorials of the numbers 0, ..., max_arg_val
* reduced modulo mod
*/
int64_t modInverse(const uint64_t& a, const uint64_t& m) {
int64_t x = 0, y = 0;
uint64_t g = gcdExtended(a, m, &x, &y);
if (g != 1) { // modular inverse doesn't exist
return -1;
} else {
int64_t res = ((x + m) % m);
return res;
static std::vector<int64_t> computeFactorialsMod(const int64_t& max_arg_val,
const int64_t& mod) {
auto res = std::vector<int64_t>(max_arg_val + 1);
res[0] = 1;
for (int64_t i = 1; i <= max_arg_val; i++) {
res[i] = (res[i - 1] * i) % mod;
}
return res;
}

/** Find nCr % p
*
* @params[in] the numbers 'n', 'r' and 'p'
* @returns the value nCr % p
public:
/**
* @brief constructs an NCRModuloP object allowing to compute (nCr)%p for
* inputs from 0 to size
*/
int64_t ncr(const uint64_t& n, const uint64_t& r, const uint64_t& p) {
NCRModuloP(const int64_t& size, const int64_t& p)
: p(p), fac(computeFactorialsMod(size, p)) {}

/**
* @brief computes nCr % p
* @param[in] n the number of objects to be chosen
* @param[in] r the number of objects to choose from
* @return the value nCr % p
*/
int64_t ncr(const int64_t& n, const int64_t& r) const {
// Base cases
if (r > n) {
return 0;
Expand All @@ -101,50 +125,71 @@ class NCRModuloP {
return 1;
}
// fac is a global array with fac[r] = (r! % p)
int64_t denominator = modInverse(fac[r], p);
if (denominator < 0) { // modular inverse doesn't exist
return -1;
}
denominator = (denominator * modInverse(fac[n - r], p)) % p;
if (denominator < 0) { // modular inverse doesn't exist
const auto denominator = (fac[r] * fac[n - r]) % p;
const auto denominator_inv = utils::modInverse(denominator, p);
if (denominator_inv < 0) { // modular inverse doesn't exist
return -1;
}
return (fac[n] * denominator) % p;
return (fac[n] * denominator_inv) % p;
}
};
} // namespace ncr_modulo_p
} // namespace math

/**
* @brief Test implementations
* @param ncrObj object which contains the precomputed factorial values and
* ncr function
* @returns void
* @brief tests math::ncr_modulo_p::NCRModuloP
vil02 marked this conversation as resolved.
Show resolved Hide resolved
*/
static void tests(math::ncr_modulo_p::NCRModuloP ncrObj) {
// (52323 C 26161) % (1e9 + 7) = 224944353
assert(ncrObj.ncr(52323, 26161, 1000000007) == 224944353);
// 6 C 2 = 30, 30%5 = 0
assert(ncrObj.ncr(6, 2, 5) == 0);
vil02 marked this conversation as resolved.
Show resolved Hide resolved
// 7C3 = 35, 35 % 29 = 8
assert(ncrObj.ncr(7, 3, 29) == 6);
static void tests() {
vil02 marked this conversation as resolved.
Show resolved Hide resolved
struct TestCase {
const int64_t size;
const int64_t p;
const int64_t n;
const int64_t r;
const int64_t expected;

TestCase(const int64_t size, const int64_t p, const int64_t n,
const int64_t r, const int64_t expected)
: size(size), p(p), n(n), r(r), expected(expected) {}
};
const std::vector<TestCase> test_cases = {
TestCase(60000, 1000000007, 52323, 26161, 224944353),
TestCase(20, 5, 6, 2, 30 % 5),
TestCase(100, 29, 7, 3, 35 % 29),
TestCase(1000, 13, 10, 3, 120 % 13),
TestCase(20, 17, 1, 10, 0),
TestCase(45, 19, 23, 1, 23 % 19),
TestCase(45, 19, 23, 0, 1),
TestCase(45, 19, 23, 23, 1),
TestCase(20, 9, 10, 2, -1)};
for (const auto& tc : test_cases) {
assert(math::ncr_modulo_p::NCRModuloP(tc.size, tc.p).ncr(tc.n, tc.r) ==
tc.expected);
}
vil02 marked this conversation as resolved.
Show resolved Hide resolved

std::cout << "\n\nAll tests have successfully passed!\n";
}

/**
* @brief Main function
* @returns 0 on exit
* @brief example showing the usage of the math::ncr_modulo_p::NCRModuloP class
*/
int main() {
// populate the fac array
const uint64_t size = 1e6 + 1;
const uint64_t p = 1e9 + 7;
math::ncr_modulo_p::NCRModuloP ncrObj =
math::ncr_modulo_p::NCRModuloP(size, p);
// test 6Ci for i=0 to 7
void example() {
const int64_t size = 1e6 + 1;
const int64_t p = 1e9 + 7;

// the ncrObj contains the precomputed values of factorials modulo p for
// values from 0 to size
const auto ncrObj = math::ncr_modulo_p::NCRModuloP(size, p);

// having the ncrObj we can efficiently query the values of (n C r)%p
// note that time of the computation does not depend on size
for (int i = 0; i <= 7; i++) {
std::cout << 6 << "C" << i << " = " << ncrObj.ncr(6, i, p) << "\n";
std::cout << 6 << "C" << i << " mod " << p << " = " << ncrObj.ncr(6, i)
<< "\n";
}
tests(ncrObj); // execute the tests
std::cout << "Assertions passed\n";
}

vil02 marked this conversation as resolved.
Show resolved Hide resolved
int main() {
tests();
example();
return 0;
}
Loading