From a7c4ff84d0cc8bda41d8f693b6304bb247b6fe9e Mon Sep 17 00:00:00 2001 From: Mihir Wadekar Date: Wed, 8 Nov 2023 17:45:44 +0530 Subject: [PATCH] feat: Adds Modexp function for modular exponentiation support in the library --- foundry.toml | 1 + hufftest.sh | 0 src/Math.huff | 45 ++++++++++++++++++++++++++++++++- src/interfaces/IMath.sol | 2 ++ src/wrappers/MathWrapper.huff | 14 +++++++++- test/foundry/Math.t.sol | 25 ++++++++++++++++++ test/foundry/MathForkTest.t.sol | 11 ++++++++ test/huff/Math.t.huff | 21 +++++++++++++++ 8 files changed, 117 insertions(+), 2 deletions(-) mode change 100644 => 100755 hufftest.sh diff --git a/foundry.toml b/foundry.toml index 9732064..5aa798b 100644 --- a/foundry.toml +++ b/foundry.toml @@ -2,5 +2,6 @@ src = 'src' out = 'artifacts' libs = ["node_modules", "lib"] +evm_version = "shanghai" ffi=true # See more config options https://github.com/foundry-rs/foundry/tree/master/config \ No newline at end of file diff --git a/hufftest.sh b/hufftest.sh old mode 100644 new mode 100755 diff --git a/src/Math.huff b/src/Math.huff index 2b690e9..315e3f8 100644 --- a/src/Math.huff +++ b/src/Math.huff @@ -42,4 +42,47 @@ complete: } - \ No newline at end of file +#define macro MOD_EXP() = takes (3) returns (1) { + // Stack input: [base, exponent, modulus] + + // Check if modulus is zero, revert if it is + dup1 // Duplicate the modulus at the top of the stack + 0x00 eq // Compare it to zero + jumpi($MODULUS_NOT_ZERO) // Jump if not equal to zero + 0x00 0x00 revert // If zero, revert + + // Label for non-zero modulus + $MODULUS_NOT_ZERO: + + // Prepare memory for staticcall to precompiled contract at 0x05 + // First, we need to store the input data (base, exponent, modulus) in memory. + // Assuming input is in the correct order on the stack: [base, exponent, modulus] + 0x20 0x00 mstore // Store base at memory 0x20 + 0x40 0x20 mstore // Store exponent at memory 0x40 + 0x60 0x40 mstore // Store modulus at memory 0x60 + + // Prepare input data for the precompiled contract + // Input format: + 0x20 0x00 mstore // Store length of base (0x20 bytes) at memory 0x00 + 0x20 0x20 mload // Load base from memory 0x20 and store at memory 0x20 + 0x20 0x40 mstore // Store length of exponent (0x20 bytes) at memory 0x40 + 0x20 0x60 mload // Load exponent from memory 0x40 and store at memory 0x60 + 0x20 0x80 mstore // Store length of modulus (0x20 bytes) at memory 0x80 + 0x20 0xa0 mload // Load modulus from memory 0x60 and store at memory 0xa0 + + // Set up the call to the precompiled contract + 0x00 0x00 0x00 0x05 gas // Address of the precompiled contract (0x05) and gas for the call + 0x00 0xc0 // Start of input data in memory and input data size (192 bytes) + 0x00 0x20 // Start of output data in memory and output data size (32 bytes) + staticcall + + // Check if the call was successful + iszero // If the call was not successful, the top of the stack will be 0 + jumpi($CALL_SUCCESSFUL) // Jump if successful + 0x00 0x00 revert // If not successful, revert + + // Label for call success + $CALL_SUCCESSFUL: + 0x00 0x20 mload // Load the result from memory location 0x00 + // The result of base^exponent % modulus is now on top of the stack +} diff --git a/src/interfaces/IMath.sol b/src/interfaces/IMath.sol index 9621d2e..054dab3 100644 --- a/src/interfaces/IMath.sol +++ b/src/interfaces/IMath.sol @@ -12,4 +12,6 @@ interface IMath { function divideNumbers(uint256, uint256) external view returns (uint256); function abs(uint256, uint256) external view returns (uint256); + + function modExp(uint256, uint256, uint256) external view returns (uint256); } diff --git a/src/wrappers/MathWrapper.huff b/src/wrappers/MathWrapper.huff index 3234522..04b9e68 100644 --- a/src/wrappers/MathWrapper.huff +++ b/src/wrappers/MathWrapper.huff @@ -6,6 +6,7 @@ #define function multiplyNumbers(uint256,uint256) nonpayable returns (uint256) #define function divideNumbers(uint256,uint256) nonpayable returns (uint256) #define function abs(uint256,uint256) nonpayable returns (uint256) +#define function modExp(uint256,uint256,uint256) nonpayable returns (uint256) #define macro ADD_WRAPPER() = takes (2) returns (1) { 0x04 calldataload // [num1] @@ -47,7 +48,14 @@ 0x20 0x00 return // [] } - + #define macro MODEXP_WRAPPER() = takes (3) returns (1) { + 0x04 calldataload // [base] + 0x24 calldataload // [exponent, base] + 0x44 calldataload // [modulus, exponent, base] + MODEXP() // [base^exponent mod modulus] + 0x00 mstore // [] + 0x20 0x00 return // [] + } #define macro MAIN() = takes (0) returns (0) { @@ -63,6 +71,7 @@ dup1 0xd3f3cd7b eq multiplyNumbers jumpi dup1 0x8fce12ed eq divideNumbers jumpi dup1 0xe093a157 eq abs jumpi + dup1 0x3148f14f eq modExp jumpi addNumbers: @@ -76,6 +85,9 @@ divideNumbers: DIVIDE_WRAPPER() + + modExp: + MODEXP_WRAPPER() abs: ABS_WRAPPER() diff --git a/test/foundry/Math.t.sol b/test/foundry/Math.t.sol index 60b4670..896747f 100644 --- a/test/foundry/Math.t.sol +++ b/test/foundry/Math.t.sol @@ -93,4 +93,29 @@ contract MathTest is Test { uint256 _result = a > b ? a - b : b - a; require(math.abs(a, b) == _result); } + + function testModExp() public { + // Example test: 2^3 % 5 should equal 3 + uint256 base = 10; + uint256 exponent = 3; + uint256 modulus = 13; + uint256 expected = 12; + + uint256 result = math.modExp(base, exponent, modulus); + assertEq(result, expected, "modExp did not return the expected value"); + } + + function testModExp_fuzz(uint256 b, uint256 e, uint256 m) public { + // To avoid testing with modulus zero, which would revert + vm.assume(m > 1); + // To avoid gas issues, cap the exponent + uint256 exponent = e % 256; + + // The actual modExp calculation can be complicated to emulate in Solidity due to gas constraints, + // so here we just test that the function does not revert and returns a value + // less than the modulus. + uint256 result = math.modExp(b, exponent, m); + + assertLt(result, m, "modExp result should be less than the modulus"); + } } diff --git a/test/foundry/MathForkTest.t.sol b/test/foundry/MathForkTest.t.sol index 7ecfcec..4a651d1 100644 --- a/test/foundry/MathForkTest.t.sol +++ b/test/foundry/MathForkTest.t.sol @@ -36,4 +36,15 @@ contract MathForkTest is Test { function testAbs() public view { require(math.abs(1, 10) == 9); } + + function testModExp() public { + // Example test: 2^3 % 5 should equal 3 + uint256 base = 2; + uint256 exponent = 3; + uint256 modulus = 5; + uint256 expected = 3; + + uint256 result = math.modExp(base, exponent, modulus); + assertEq(result, expected, "modExp did not return the expected value"); + } } diff --git a/test/huff/Math.t.huff b/test/huff/Math.t.huff index 66e031c..6430733 100644 --- a/test/huff/Math.t.huff +++ b/test/huff/Math.t.huff @@ -75,3 +75,24 @@ ASSERT_EQ() // [4e18==result] } +#define test TEST_MODEXP() = { + // Test case 1: Simple modular exponentiation + // Using small numbers for easy manual verification: 2^3 % 5 = 3 + 0x05 // [modulus = 5] + 0x03 // [exponent = 3, modulus] + 0x02 // [base = 2, exponent, modulus] + MODEXP() // [result] + 0x03 // [expected = 3, result] + ASSERT_EQ() // [3 == result] + + // Test case 2: Larger numbers + // We need to choose numbers such that we can calculate the expected result manually or with a tool + // For example: (0x04)^2 % 0x05 = 0x04 + 0x05 // [modulus = 5] + 0x02 // [exponent = 2, modulus] + 0x04 // [base = 4, exponent, modulus] + MODEXP() // [result] + 0x04 // [expected = 4, result] + ASSERT_EQ() // [4 == result] +} +