diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..d79c399 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "examples/eth/lib/forge-std"] + path = examples/eth/lib/forge-std + url = https://github.com/foundry-rs/forge-std diff --git a/cli/src/main.rs b/cli/src/main.rs index 6293a2c..2636de5 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -81,8 +81,7 @@ async fn generate_tx_proof(tx_hash: &str, rpc_url: &str) -> Result<(), EthTrieEr let root = txs_mpt_handler.get_root()?; let mpt_proof = MptProof { root, proof, index }; - println!("Generated TX Proof: "); - println!("{}", serde_json::to_string(&mpt_proof).unwrap()); + print!("{}", serde_json::to_string(&mpt_proof).unwrap()); Ok(()) } @@ -98,7 +97,6 @@ async fn generate_receipt_proof(tx_hash: &str, rpc_url: &str) -> Result<(), EthT let root = tx_receipts_mpt_handler.get_root()?; let mpt_proof = MptProof { root, proof, index }; - println!("Generated Receipt Proof: "); - println!("{}", serde_json::to_string(&mpt_proof).unwrap()); + print!("{}", serde_json::to_string(&mpt_proof).unwrap()); Ok(()) } diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..1643ef8 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,5 @@ +# Trie Proof examples + +Available examples: + +- [Ethereum transaction inclusion proof (Solidity)](./eth/) diff --git a/examples/eth/.gitignore b/examples/eth/.gitignore new file mode 100644 index 0000000..85198aa --- /dev/null +++ b/examples/eth/.gitignore @@ -0,0 +1,14 @@ +# Compiler files +cache/ +out/ + +# Ignores development broadcast logs +!/broadcast +/broadcast/*/31337/ +/broadcast/**/dry-run/ + +# Docs +docs/ + +# Dotenv file +.env diff --git a/examples/eth/README.md b/examples/eth/README.md new file mode 100644 index 0000000..482cdb6 --- /dev/null +++ b/examples/eth/README.md @@ -0,0 +1,30 @@ +# Ethereum Trie Proof: Solidity example + +This example demonstrates how to verify an Ethereum transaction inclusion proof in Solidity. +The proof will be generated with the [`cli`](../../cli/) binary directly from the Foundry test. + +## Overview of the example + +The example consists of a single end2end test in [Prover.t.sol](./test/Prover.t.sol) that generates a Merkle +inclusion proof and verifies it in Solidity against a given transaction hash. + +## Usage + +Run the Foundry tests: + +```shell +# make sure you are in the right directory +cd examples/eth + +forge test --ffi +``` + +## Troubleshooting + +If the test fails with an RPC error, you can customize the RPC endpoint by setting the `RPC_URL` +environment variable to the desired endpoing in your shell. For example: + +```shell +export RPC_URL=https://mainnet.infura.io/v3/your_project_id +forge test --ffi +``` diff --git a/examples/eth/foundry.toml b/examples/eth/foundry.toml new file mode 100644 index 0000000..25b918f --- /dev/null +++ b/examples/eth/foundry.toml @@ -0,0 +1,6 @@ +[profile.default] +src = "src" +out = "out" +libs = ["lib"] + +# See more config options https://github.com/foundry-rs/foundry/blob/master/crates/config/README.md#all-options diff --git a/examples/eth/lib/forge-std b/examples/eth/lib/forge-std new file mode 160000 index 0000000..1714bee --- /dev/null +++ b/examples/eth/lib/forge-std @@ -0,0 +1 @@ +Subproject commit 1714bee72e286e73f76e320d110e0eaf5c4e649d diff --git a/examples/eth/src/Prover.sol b/examples/eth/src/Prover.sol new file mode 100644 index 0000000..faccacf --- /dev/null +++ b/examples/eth/src/Prover.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.25; + +import {MerkleTrie} from "./lib/MerkleTrie.sol"; + +contract Prover { + constructor() {} + + function get(bytes memory _key, bytes memory _proof, bytes32 _root) public pure returns (bool, bytes memory) { + return MerkleTrie.get(_key, _proof, _root); + } +} diff --git a/examples/eth/src/lib/BytesUtils.sol b/examples/eth/src/lib/BytesUtils.sol new file mode 100644 index 0000000..46e86f9 --- /dev/null +++ b/examples/eth/src/lib/BytesUtils.sol @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.25; + +/** + * @title BytesUtils + */ +library BytesUtils { + /** + * + * Internal Functions * + * + */ + function slice(bytes memory _bytes, uint256 _start, uint256 _length) internal pure returns (bytes memory) { + unchecked { + require(_length + 31 >= _length, "slice_overflow"); + require(_start + _length >= _start, "slice_overflow"); + require(_bytes.length >= _start + _length, "slice_outOfBounds"); + + bytes memory tempBytes; + + assembly { + switch iszero(_length) + case 0 { + // Get a location of some free memory and store it in tempBytes as + // Solidity does for memory variables. + tempBytes := mload(0x40) + + // The first word of the slice result is potentially a partial + // word read from the original array. To read it, we calculate + // the length of that partial word and start copying that many + // bytes into the array. The first word we copy will start with + // data we don't care about, but the last `lengthmod` bytes will + // land at the beginning of the contents of the new array. When + // we're done copying, we overwrite the full first word with + // the actual length of the slice. + let lengthmod := and(_length, 31) + + // The multiplication in the next line is necessary + // because when slicing multiples of 32 bytes (lengthmod == 0) + // the following copy loop was copying the origin's length + // and then ending prematurely not copying everything it should. + let mc := add(add(tempBytes, lengthmod), mul(0x20, iszero(lengthmod))) + let end := add(mc, _length) + + for { + // The multiplication in the next line has the same exact purpose + // as the one above. + let cc := add(add(add(_bytes, lengthmod), mul(0x20, iszero(lengthmod))), _start) + } lt(mc, end) { + mc := add(mc, 0x20) + cc := add(cc, 0x20) + } { mstore(mc, mload(cc)) } + + mstore(tempBytes, _length) + + //update free-memory pointer + //allocating the array padded to 32 bytes like the compiler does now + mstore(0x40, and(add(mc, 31), not(31))) + } + //if we want a zero-length slice let's just return a zero-length array + default { + tempBytes := mload(0x40) + + //zero out the 32 bytes slice we are about to return + //we need to do it because Solidity does not garbage collect + mstore(tempBytes, 0) + + mstore(0x40, add(tempBytes, 0x20)) + } + } + + return tempBytes; + } + } + + function slice(bytes memory _bytes, uint256 _start) internal pure returns (bytes memory) { + unchecked { + if (_bytes.length - _start == 0) { + return bytes(""); + } + + return slice(_bytes, _start, _bytes.length - _start); + } + } + + function toBytes32PadLeft(bytes memory _bytes) internal pure returns (bytes32) { + unchecked { + bytes32 ret; + uint256 len = _bytes.length <= 32 ? _bytes.length : 32; + assembly { + ret := shr(mul(sub(32, len), 8), mload(add(_bytes, 32))) + } + return ret; + } + } + + function toBytes32(bytes memory _bytes) internal pure returns (bytes32) { + unchecked { + if (_bytes.length < 32) { + bytes32 ret; + assembly { + ret := mload(add(_bytes, 32)) + } + return ret; + } + + return abi.decode(_bytes, (bytes32)); // will truncate if input length > 32 bytes + } + } + + function toUint256(bytes memory _bytes) internal pure returns (uint256) { + return uint256(toBytes32(_bytes)); + } + + function toUint24(bytes memory _bytes, uint256 _start) internal pure returns (uint24) { + require(_start + 3 >= _start, "toUint24_overflow"); + require(_bytes.length >= _start + 3, "toUint24_outOfBounds"); + uint24 tempUint; + + assembly { + tempUint := mload(add(add(_bytes, 0x3), _start)) + } + + return tempUint; + } + + function toUint8(bytes memory _bytes, uint256 _start) internal pure returns (uint8) { + require(_start + 1 >= _start, "toUint8_overflow"); + require(_bytes.length >= _start + 1, "toUint8_outOfBounds"); + uint8 tempUint; + + assembly { + tempUint := mload(add(add(_bytes, 0x1), _start)) + } + + return tempUint; + } + + function toAddress(bytes memory _bytes, uint256 _start) internal pure returns (address) { + require(_start + 20 >= _start, "toAddress_overflow"); + require(_bytes.length >= _start + 20, "toAddress_outOfBounds"); + address tempAddress; + + assembly { + tempAddress := div(mload(add(add(_bytes, 0x20), _start)), 0x1000000000000000000000000) + } + + return tempAddress; + } + + function toNibbles(bytes memory _bytes) internal pure returns (bytes memory) { + unchecked { + bytes memory nibbles = new bytes(_bytes.length * 2); + + for (uint256 i = 0; i < _bytes.length; i++) { + nibbles[i * 2] = _bytes[i] >> 4; + nibbles[i * 2 + 1] = bytes1(uint8(_bytes[i]) % 16); + } + + return nibbles; + } + } + + function fromNibbles(bytes memory _bytes) internal pure returns (bytes memory) { + unchecked { + bytes memory ret = new bytes(_bytes.length / 2); + + for (uint256 i = 0; i < ret.length; i++) { + ret[i] = (_bytes[i * 2] << 4) | (_bytes[i * 2 + 1]); + } + + return ret; + } + } + + function equal(bytes memory _bytes, bytes memory _other) internal pure returns (bool) { + return keccak256(_bytes) == keccak256(_other); + } +} diff --git a/examples/eth/src/lib/MerkleTrie.sol b/examples/eth/src/lib/MerkleTrie.sol new file mode 100644 index 0000000..f54caa0 --- /dev/null +++ b/examples/eth/src/lib/MerkleTrie.sol @@ -0,0 +1,737 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.25; + +/* Library Imports */ +import {BytesUtils} from "./BytesUtils.sol"; +import {RLPReader} from "./RLPReader.sol"; +import {RLPWriter} from "./RLPWriter.sol"; + +/** + * @title MerkleTrie + */ +library MerkleTrie { + /** + * + * Data Structures * + * + */ + enum NodeType { + BranchNode, + ExtensionNode, + LeafNode + } + + struct TrieNode { + bytes encoded; + RLPReader.RLPItem[] decoded; + } + + /** + * + * Contract Constants * + * + */ + + // TREE_RADIX determines the number of elements per branch node. + uint256 constant TREE_RADIX = 16; + // Branch nodes have TREE_RADIX elements plus an additional `value` slot. + uint256 constant BRANCH_NODE_LENGTH = TREE_RADIX + 1; + // Leaf nodes and extension nodes always have two elements, a `path` and a `value`. + uint256 constant LEAF_OR_EXTENSION_NODE_LENGTH = 2; + + // Prefixes are prepended to the `path` within a leaf or extension node and + // allow us to differentiate between the two node types. `ODD` or `EVEN` is + // determined by the number of nibbles within the unprefixed `path`. If the + // number of nibbles if even, we need to insert an extra padding nibble so + // the resulting prefixed `path` has an even number of nibbles. + uint8 constant PREFIX_EXTENSION_EVEN = 0; + uint8 constant PREFIX_EXTENSION_ODD = 1; + uint8 constant PREFIX_LEAF_EVEN = 2; + uint8 constant PREFIX_LEAF_ODD = 3; + + // Just a utility constant. RLP represents `NULL` as 0x80. + bytes1 constant RLP_NULL = bytes1(0x80); + bytes constant RLP_NULL_BYTES = hex"80"; + bytes32 internal constant KECCAK256_RLP_NULL_BYTES = keccak256(RLP_NULL_BYTES); + + /** + * + * Internal Functions * + * + */ + + /** + * @notice Verifies a proof that a given key/value pair is present in the + * Merkle trie. + * @param _key Key of the node to search for, as a hex string. + * @param _value Value of the node to search for, as a hex string. + * @param _proof Merkle trie inclusion proof for the desired node. Unlike + * traditional Merkle trees, this proof is executed top-down and consists + * of a list of RLP-encoded nodes that make a path down to the target node. + * @param _root Known root of the Merkle trie. Used to verify that the + * included proof is correctly constructed. + * @return _verified `true` if the k/v pair exists in the trie, `false` otherwise. + */ + function verifyInclusionProof(bytes memory _key, bytes memory _value, bytes memory _proof, bytes32 _root) + internal + pure + returns (bool _verified) + { + (bool exists, bytes memory value) = get(_key, _proof, _root); + + return (exists && BytesUtils.equal(_value, value)); + } + + /** + * @notice Verifies a proof that a given key is *not* present in + * the Merkle trie. + * @param _key Key of the node to search for, as a hex string. + * @param _proof Merkle trie inclusion proof for the node *nearest* the + * target node. + * @param _root Known root of the Merkle trie. Used to verify that the + * included proof is correctly constructed. + * @return _verified `true` if the key is absent in the trie, `false` otherwise. + */ + function verifyExclusionProof(bytes memory _key, bytes memory _proof, bytes32 _root) + internal + pure + returns (bool _verified) + { + (bool exists,) = get(_key, _proof, _root); + + return exists == false; + } + + /** + * @notice Updates a Merkle trie and returns a new root hash. + * @param _key Key of the node to update, as a hex string. + * @param _value Value of the node to update, as a hex string. + * @param _proof Merkle trie inclusion proof for the node *nearest* the + * target node. If the key exists, we can simply update the value. + * Otherwise, we need to modify the trie to handle the new k/v pair. + * @param _root Known root of the Merkle trie. Used to verify that the + * included proof is correctly constructed. + * @return _updatedRoot Root hash of the newly constructed trie. + */ + function update(bytes memory _key, bytes memory _value, bytes memory _proof, bytes32 _root) + internal + pure + returns (bytes32 _updatedRoot) + { + // Special case when inserting the very first node. + if (_root == KECCAK256_RLP_NULL_BYTES) { + return getSingleNodeRootHash(_key, _value); + } + + TrieNode[] memory proof = _parseProof(_proof); + (uint256 pathLength, bytes memory keyRemainder,) = _walkNodePath(proof, _key, _root); + TrieNode[] memory newPath = _getNewPath(proof, pathLength, keyRemainder, _value); + + return _getUpdatedTrieRoot(newPath, _key); + } + + /** + * @notice Retrieves the value associated with a given key. + * @param _key Key to search for, as hex bytes. + * @param _proof Merkle trie inclusion proof for the key. + * @param _root Known root of the Merkle trie. + * @return _exists Whether or not the key exists. + * @return _value Value of the key if it exists. + */ + function get(bytes memory _key, bytes memory _proof, bytes32 _root) + internal + pure + returns (bool _exists, bytes memory _value) + { + TrieNode[] memory proof = _parseProof(_proof); + (uint256 pathLength, bytes memory keyRemainder, bool isFinalNode) = _walkNodePath(proof, _key, _root); + + bool exists = keyRemainder.length == 0; + + require(exists || isFinalNode, "Provided proof is invalid."); + + bytes memory value = exists ? _getNodeValue(proof[pathLength - 1]) : bytes(""); + + return (exists, value); + } + + /** + * Computes the root hash for a trie with a single node. + * @param _key Key for the single node. + * @param _value Value for the single node. + * @return _updatedRoot Hash of the trie. + */ + function getSingleNodeRootHash(bytes memory _key, bytes memory _value) + internal + pure + returns (bytes32 _updatedRoot) + { + return keccak256(_makeLeafNode(BytesUtils.toNibbles(_key), _value).encoded); + } + + /** + * + * Private Functions * + * + */ + + /** + * @notice Walks through a proof using a provided key. + * @param _proof Inclusion proof to walk through. + * @param _key Key to use for the walk. + * @param _root Known root of the trie. + * @return _pathLength Length of the final path + * @return _keyRemainder Portion of the key remaining after the walk. + * @return _isFinalNode Whether or not we've hit a dead end. + */ + function _walkNodePath(TrieNode[] memory _proof, bytes memory _key, bytes32 _root) + private + pure + returns (uint256 _pathLength, bytes memory _keyRemainder, bool _isFinalNode) + { + uint256 pathLength = 0; + bytes memory key = BytesUtils.toNibbles(_key); + + bytes32 currentNodeID = _root; + uint256 currentKeyIndex = 0; + uint256 currentKeyIncrement = 0; + TrieNode memory currentNode; + + // Proof is top-down, so we start at the first element (root). + for (uint256 i = 0; i < _proof.length; i++) { + currentNode = _proof[i]; + currentKeyIndex += currentKeyIncrement; + + // Keep track of the proof elements we actually need. + // It's expensive to resize arrays, so this simply reduces gas costs. + pathLength += 1; + + if (currentKeyIndex == 0) { + // First proof element is always the root node. + require(keccak256(currentNode.encoded) == currentNodeID, "Invalid root hash"); + } else if (currentNode.encoded.length >= 32) { + // Nodes 32 bytes or larger are hashed inside branch nodes. + require(keccak256(currentNode.encoded) == currentNodeID, "Invalid large internal hash"); + } else { + // Nodes smaller than 31 bytes aren't hashed. + require(BytesUtils.toBytes32(currentNode.encoded) == currentNodeID, "Invalid internal node hash"); + } + + if (currentNode.decoded.length == BRANCH_NODE_LENGTH) { + if (currentKeyIndex == key.length) { + // We've hit the end of the key, meaning the value should be within this branch node. + break; + } else { + // We're not at the end of the key yet. + // Figure out what the next node ID should be and continue. + uint8 branchKey = uint8(key[currentKeyIndex]); + RLPReader.RLPItem memory nextNode = currentNode.decoded[branchKey]; + currentNodeID = _getNodeID(nextNode); + currentKeyIncrement = 1; + continue; + } + } else if (currentNode.decoded.length == LEAF_OR_EXTENSION_NODE_LENGTH) { + bytes memory path = _getNodePath(currentNode); + uint8 prefix = uint8(path[0]); + uint8 offset = 2 - prefix % 2; + bytes memory pathRemainder = BytesUtils.slice(path, offset); + bytes memory keyRemainder = BytesUtils.slice(key, currentKeyIndex); + uint256 sharedNibbleLength = _getSharedNibbleLength(pathRemainder, keyRemainder); + + if (prefix == PREFIX_LEAF_EVEN || prefix == PREFIX_LEAF_ODD) { + if (pathRemainder.length == sharedNibbleLength && keyRemainder.length == sharedNibbleLength) { + // The key within this leaf matches our key exactly. + // Increment the key index to reflect that we have no remainder. + currentKeyIndex += sharedNibbleLength; + } + + // We've hit a leaf node, so our next node should be NULL. + currentNodeID = bytes32(RLP_NULL); + break; + } else if (prefix == PREFIX_EXTENSION_EVEN || prefix == PREFIX_EXTENSION_ODD) { + if (sharedNibbleLength == 0) { + // Our extension node doesn't share any part of our key. + // We've hit the end of this path, updates will need to modify this extension. + currentNodeID = bytes32(RLP_NULL); + break; + } else { + // Our extension shares some nibbles. + // Carry on to the next node. + currentNodeID = _getNodeID(currentNode.decoded[1]); + currentKeyIncrement = sharedNibbleLength; + continue; + } + } else { + revert("Received a node with an unknown prefix"); + } + } else { + revert("Received an unparseable node."); + } + } + + // If our node ID is NULL, then we're at a dead end. + bool isFinalNode = currentNodeID == bytes32(RLP_NULL); + return (pathLength, BytesUtils.slice(key, currentKeyIndex), isFinalNode); + } + + /** + * @notice Creates new nodes to support a k/v pair insertion into a given + * Merkle trie path. + * @param _path Path to the node nearest the k/v pair. + * @param _pathLength Length of the path. Necessary because the provided + * path may include additional nodes (e.g., it comes directly from a proof) + * and we can't resize in-memory arrays without costly duplication. + * @param _keyRemainder Portion of the initial key that must be inserted + * into the trie. + * @param _value Value to insert at the given key. + * @return _newPath A new path with the inserted k/v pair and extra supporting nodes. + */ + function _getNewPath(TrieNode[] memory _path, uint256 _pathLength, bytes memory _keyRemainder, bytes memory _value) + private + pure + returns (TrieNode[] memory _newPath) + { + bytes memory keyRemainder = _keyRemainder; + + // Most of our logic depends on the status of the last node in the path. + TrieNode memory lastNode = _path[_pathLength - 1]; + NodeType lastNodeType = _getNodeType(lastNode); + + // Create an array for newly created nodes. + // We need up to three new nodes, depending on the contents of the last node. + // Since array resizing is expensive, we'll keep track of the size manually. + // We're using an explicit `totalNewNodes += 1` after insertions for clarity. + TrieNode[] memory newNodes = new TrieNode[](3); + uint256 totalNewNodes = 0; + + if (keyRemainder.length == 0 && lastNodeType == NodeType.LeafNode) { + // We've found a leaf node with the given key. + // Simply need to update the value of the node to match. + newNodes[totalNewNodes] = _makeLeafNode(_getNodeKey(lastNode), _value); + totalNewNodes += 1; + } else if (lastNodeType == NodeType.BranchNode) { + if (keyRemainder.length == 0) { + // We've found a branch node with the given key. + // Simply need to update the value of the node to match. + newNodes[totalNewNodes] = _editBranchValue(lastNode, _value); + totalNewNodes += 1; + } else { + // We've found a branch node, but it doesn't contain our key. + // Reinsert the old branch for now. + newNodes[totalNewNodes] = lastNode; + totalNewNodes += 1; + // Create a new leaf node, slicing our remainder since the first byte points + // to our branch node. + newNodes[totalNewNodes] = _makeLeafNode(BytesUtils.slice(keyRemainder, 1), _value); + totalNewNodes += 1; + } + } else { + // Our last node is either an extension node or a leaf node with a different key. + bytes memory lastNodeKey = _getNodeKey(lastNode); + uint256 sharedNibbleLength = _getSharedNibbleLength(lastNodeKey, keyRemainder); + + if (sharedNibbleLength != 0) { + // We've got some shared nibbles between the last node and our key remainder. + // We'll need to insert an extension node that covers these shared nibbles. + bytes memory nextNodeKey = BytesUtils.slice(lastNodeKey, 0, sharedNibbleLength); + newNodes[totalNewNodes] = _makeExtensionNode(nextNodeKey, _getNodeHash(_value)); + totalNewNodes += 1; + + // Cut down the keys since we've just covered these shared nibbles. + lastNodeKey = BytesUtils.slice(lastNodeKey, sharedNibbleLength); + keyRemainder = BytesUtils.slice(keyRemainder, sharedNibbleLength); + } + + // Create an empty branch to fill in. + TrieNode memory newBranch = _makeEmptyBranchNode(); + + if (lastNodeKey.length == 0) { + // Key remainder was larger than the key for our last node. + // The value within our last node is therefore going to be shifted into + // a branch value slot. + newBranch = _editBranchValue(newBranch, _getNodeValue(lastNode)); + } else { + // Last node key was larger than the key remainder. + // We're going to modify some index of our branch. + uint8 branchKey = uint8(lastNodeKey[0]); + // Move on to the next nibble. + lastNodeKey = BytesUtils.slice(lastNodeKey, 1); + + if (lastNodeType == NodeType.LeafNode) { + // We're dealing with a leaf node. + // We'll modify the key and insert the old leaf node into the branch index. + TrieNode memory modifiedLastNode = _makeLeafNode(lastNodeKey, _getNodeValue(lastNode)); + newBranch = _editBranchIndex(newBranch, branchKey, _getNodeHash(modifiedLastNode.encoded)); + } else if (lastNodeKey.length != 0) { + // We're dealing with a shrinking extension node. + // We need to modify the node to decrease the size of the key. + TrieNode memory modifiedLastNode = _makeExtensionNode(lastNodeKey, _getNodeValue(lastNode)); + newBranch = _editBranchIndex(newBranch, branchKey, _getNodeHash(modifiedLastNode.encoded)); + } else { + // We're dealing with an unnecessary extension node. + // We're going to delete the node entirely. + // Simply insert its current value into the branch index. + newBranch = _editBranchIndex(newBranch, branchKey, _getNodeValue(lastNode)); + } + } + + if (keyRemainder.length == 0) { + // We've got nothing left in the key remainder. + // Simply insert the value into the branch value slot. + newBranch = _editBranchValue(newBranch, _value); + // Push the branch into the list of new nodes. + newNodes[totalNewNodes] = newBranch; + totalNewNodes += 1; + } else { + // We've got some key remainder to work with. + // We'll be inserting a leaf node into the trie. + // First, move on to the next nibble. + keyRemainder = BytesUtils.slice(keyRemainder, 1); + // Push the branch into the list of new nodes. + newNodes[totalNewNodes] = newBranch; + totalNewNodes += 1; + // Push a new leaf node for our k/v pair. + newNodes[totalNewNodes] = _makeLeafNode(keyRemainder, _value); + totalNewNodes += 1; + } + } + + // Finally, join the old path with our newly created nodes. + // Since we're overwriting the last node in the path, we use `_pathLength - 1`. + return _joinNodeArrays(_path, _pathLength - 1, newNodes, totalNewNodes); + } + + /** + * @notice Computes the trie root from a given path. + * @param _nodes Path to some k/v pair. + * @param _key Key for the k/v pair. + * @return _updatedRoot Root hash for the updated trie. + */ + function _getUpdatedTrieRoot(TrieNode[] memory _nodes, bytes memory _key) + private + pure + returns (bytes32 _updatedRoot) + { + bytes memory key = BytesUtils.toNibbles(_key); + + // Some variables to keep track of during iteration. + TrieNode memory currentNode; + NodeType currentNodeType; + bytes memory previousNodeHash; + + // Run through the path backwards to rebuild our root hash. + for (uint256 i = _nodes.length; i > 0; i--) { + // Pick out the current node. + currentNode = _nodes[i - 1]; + currentNodeType = _getNodeType(currentNode); + + if (currentNodeType == NodeType.LeafNode) { + // Leaf nodes are already correctly encoded. + // Shift the key over to account for the nodes key. + bytes memory nodeKey = _getNodeKey(currentNode); + key = BytesUtils.slice(key, 0, key.length - nodeKey.length); + } else if (currentNodeType == NodeType.ExtensionNode) { + // Shift the key over to account for the nodes key. + bytes memory nodeKey = _getNodeKey(currentNode); + key = BytesUtils.slice(key, 0, key.length - nodeKey.length); + + // If this node is the last element in the path, it'll be correctly encoded + // and we can skip this part. + if (previousNodeHash.length > 0) { + // Re-encode the node based on the previous node. + currentNode = _makeExtensionNode(nodeKey, previousNodeHash); + } + } else if (currentNodeType == NodeType.BranchNode) { + // If this node is the last element in the path, it'll be correctly encoded + // and we can skip this part. + if (previousNodeHash.length > 0) { + // Re-encode the node based on the previous node. + uint8 branchKey = uint8(key[key.length - 1]); + key = BytesUtils.slice(key, 0, key.length - 1); + currentNode = _editBranchIndex(currentNode, branchKey, previousNodeHash); + } + } + + // Compute the node hash for the next iteration. + previousNodeHash = _getNodeHash(currentNode.encoded); + } + + // Current node should be the root at this point. + // Simply return the hash of its encoding. + return keccak256(currentNode.encoded); + } + + /** + * @notice Parses an RLP-encoded proof into something more useful. + * @param _proof RLP-encoded proof to parse. + * @return _parsed Proof parsed into easily accessible structs. + */ + function _parseProof(bytes memory _proof) private pure returns (TrieNode[] memory _parsed) { + RLPReader.RLPItem[] memory nodes = RLPReader.readList(_proof); + TrieNode[] memory proof = new TrieNode[](nodes.length); + + for (uint256 i = 0; i < nodes.length; i++) { + bytes memory encoded = RLPReader.readBytes(nodes[i]); + proof[i] = TrieNode({encoded: encoded, decoded: RLPReader.readList(encoded)}); + } + + return proof; + } + + /** + * @notice Picks out the ID for a node. Node ID is referred to as the + * "hash" within the specification, but nodes < 32 bytes are not actually + * hashed. + * @param _node Node to pull an ID for. + * @return _nodeID ID for the node, depending on the size of its contents. + */ + function _getNodeID(RLPReader.RLPItem memory _node) private pure returns (bytes32 _nodeID) { + bytes memory nodeID; + + if (_node.length < 32) { + // Nodes smaller than 32 bytes are RLP encoded. + nodeID = RLPReader.readRawBytes(_node); + } else { + // Nodes 32 bytes or larger are hashed. + nodeID = RLPReader.readBytes(_node); + } + + return BytesUtils.toBytes32(nodeID); + } + + /** + * @notice Gets the path for a leaf or extension node. + * @param _node Node to get a path for. + * @return _path Node path, converted to an array of nibbles. + */ + function _getNodePath(TrieNode memory _node) private pure returns (bytes memory _path) { + return BytesUtils.toNibbles(RLPReader.readBytes(_node.decoded[0])); + } + + /** + * @notice Gets the key for a leaf or extension node. Keys are essentially + * just paths without any prefix. + * @param _node Node to get a key for. + * @return _key Node key, converted to an array of nibbles. + */ + function _getNodeKey(TrieNode memory _node) private pure returns (bytes memory _key) { + return _removeHexPrefix(_getNodePath(_node)); + } + + /** + * @notice Gets the path for a node. + * @param _node Node to get a value for. + * @return _value Node value, as hex bytes. + */ + function _getNodeValue(TrieNode memory _node) private pure returns (bytes memory _value) { + return RLPReader.readBytes(_node.decoded[_node.decoded.length - 1]); + } + + /** + * @notice Computes the node hash for an encoded node. Nodes < 32 bytes + * are not hashed, all others are keccak256 hashed. + * @param _encoded Encoded node to hash. + * @return _hash Hash of the encoded node. Simply the input if < 32 bytes. + */ + function _getNodeHash(bytes memory _encoded) private pure returns (bytes memory _hash) { + if (_encoded.length < 32) { + return _encoded; + } else { + return abi.encodePacked(keccak256(_encoded)); + } + } + + /** + * @notice Determines the type for a given node. + * @param _node Node to determine a type for. + * @return _type Type of the node; BranchNode/ExtensionNode/LeafNode. + */ + function _getNodeType(TrieNode memory _node) private pure returns (NodeType _type) { + if (_node.decoded.length == BRANCH_NODE_LENGTH) { + return NodeType.BranchNode; + } else if (_node.decoded.length == LEAF_OR_EXTENSION_NODE_LENGTH) { + bytes memory path = _getNodePath(_node); + uint8 prefix = uint8(path[0]); + + if (prefix == PREFIX_LEAF_EVEN || prefix == PREFIX_LEAF_ODD) { + return NodeType.LeafNode; + } else if (prefix == PREFIX_EXTENSION_EVEN || prefix == PREFIX_EXTENSION_ODD) { + return NodeType.ExtensionNode; + } + } + + revert("Invalid node type"); + } + + /** + * @notice Utility; determines the number of nibbles shared between two + * nibble arrays. + * @param _a First nibble array. + * @param _b Second nibble array. + * @return _shared Number of shared nibbles. + */ + function _getSharedNibbleLength(bytes memory _a, bytes memory _b) private pure returns (uint256 _shared) { + uint256 i = 0; + while (_a.length > i && _b.length > i && _a[i] == _b[i]) { + i++; + } + return i; + } + + /** + * @notice Utility; converts an RLP-encoded node into our nice struct. + * @param _raw RLP-encoded node to convert. + * @return _node Node as a TrieNode struct. + */ + function _makeNode(bytes[] memory _raw) private pure returns (TrieNode memory _node) { + bytes memory encoded = RLPWriter.writeList(_raw); + + return TrieNode({encoded: encoded, decoded: RLPReader.readList(encoded)}); + } + + /** + * @notice Utility; converts an RLP-decoded node into our nice struct. + * @param _items RLP-decoded node to convert. + * @return _node Node as a TrieNode struct. + */ + function _makeNode(RLPReader.RLPItem[] memory _items) private pure returns (TrieNode memory _node) { + bytes[] memory raw = new bytes[](_items.length); + for (uint256 i = 0; i < _items.length; i++) { + raw[i] = RLPReader.readRawBytes(_items[i]); + } + return _makeNode(raw); + } + + /** + * @notice Creates a new extension node. + * @param _key Key for the extension node, unprefixed. + * @param _value Value for the extension node. + * @return _node New extension node with the given k/v pair. + */ + function _makeExtensionNode(bytes memory _key, bytes memory _value) private pure returns (TrieNode memory _node) { + bytes[] memory raw = new bytes[](2); + bytes memory key = _addHexPrefix(_key, false); + raw[0] = RLPWriter.writeBytes(BytesUtils.fromNibbles(key)); + raw[1] = RLPWriter.writeBytes(_value); + return _makeNode(raw); + } + + /** + * @notice Creates a new leaf node. + * @dev This function is essentially identical to `_makeExtensionNode`. + * Although we could route both to a single method with a flag, it's + * more gas efficient to keep them separate and duplicate the logic. + * @param _key Key for the leaf node, unprefixed. + * @param _value Value for the leaf node. + * @return _node New leaf node with the given k/v pair. + */ + function _makeLeafNode(bytes memory _key, bytes memory _value) private pure returns (TrieNode memory _node) { + bytes[] memory raw = new bytes[](2); + bytes memory key = _addHexPrefix(_key, true); + raw[0] = RLPWriter.writeBytes(BytesUtils.fromNibbles(key)); + raw[1] = RLPWriter.writeBytes(_value); + return _makeNode(raw); + } + + /** + * @notice Creates an empty branch node. + * @return _node Empty branch node as a TrieNode struct. + */ + function _makeEmptyBranchNode() private pure returns (TrieNode memory _node) { + bytes[] memory raw = new bytes[](BRANCH_NODE_LENGTH); + for (uint256 i = 0; i < raw.length; i++) { + raw[i] = RLP_NULL_BYTES; + } + return _makeNode(raw); + } + + /** + * @notice Modifies the value slot for a given branch. + * @param _branch Branch node to modify. + * @param _value Value to insert into the branch. + * @return _updatedNode Modified branch node. + */ + function _editBranchValue(TrieNode memory _branch, bytes memory _value) + private + pure + returns (TrieNode memory _updatedNode) + { + bytes memory encoded = RLPWriter.writeBytes(_value); + _branch.decoded[_branch.decoded.length - 1] = RLPReader.toRLPItem(encoded); + return _makeNode(_branch.decoded); + } + + /** + * @notice Modifies a slot at an index for a given branch. + * @param _branch Branch node to modify. + * @param _index Slot index to modify. + * @param _value Value to insert into the slot. + * @return _updatedNode Modified branch node. + */ + function _editBranchIndex(TrieNode memory _branch, uint8 _index, bytes memory _value) + private + pure + returns (TrieNode memory _updatedNode) + { + bytes memory encoded = _value.length < 32 ? _value : RLPWriter.writeBytes(_value); + _branch.decoded[_index] = RLPReader.toRLPItem(encoded); + return _makeNode(_branch.decoded); + } + + /** + * @notice Utility; adds a prefix to a key. + * @param _key Key to prefix. + * @param _isLeaf Whether or not the key belongs to a leaf. + * @return _prefixedKey Prefixed key. + */ + function _addHexPrefix(bytes memory _key, bool _isLeaf) private pure returns (bytes memory _prefixedKey) { + uint8 prefix = _isLeaf ? uint8(0x02) : uint8(0x00); + uint8 offset = uint8(_key.length % 2); + bytes memory prefixed = new bytes(2 - offset); + prefixed[0] = bytes1(prefix + offset); + return abi.encodePacked(prefixed, _key); + } + + /** + * @notice Utility; removes a prefix from a path. + * @param _path Path to remove the prefix from. + * @return _unprefixedKey Unprefixed key. + */ + function _removeHexPrefix(bytes memory _path) private pure returns (bytes memory _unprefixedKey) { + if (uint8(_path[0]) % 2 == 0) { + return BytesUtils.slice(_path, 2); + } else { + return BytesUtils.slice(_path, 1); + } + } + + /** + * @notice Utility; combines two node arrays. Array lengths are required + * because the actual lengths may be longer than the filled lengths. + * Array resizing is extremely costly and should be avoided. + * @param _a First array to join. + * @param _aLength Length of the first array. + * @param _b Second array to join. + * @param _bLength Length of the second array. + * @return _joined Combined node array. + */ + function _joinNodeArrays(TrieNode[] memory _a, uint256 _aLength, TrieNode[] memory _b, uint256 _bLength) + private + pure + returns (TrieNode[] memory _joined) + { + TrieNode[] memory ret = new TrieNode[](_aLength + _bLength); + + // Copy elements from the first array. + for (uint256 i = 0; i < _aLength; i++) { + ret[i] = _a[i]; + } + + // Copy elements from the second array. + for (uint256 i = 0; i < _bLength; i++) { + ret[i + _aLength] = _b[i]; + } + + return ret; + } +} diff --git a/examples/eth/src/lib/RLPReader.sol b/examples/eth/src/lib/RLPReader.sol new file mode 100644 index 0000000..339c9fd --- /dev/null +++ b/examples/eth/src/lib/RLPReader.sol @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.25; + +/** + * @title RLPReader + * @dev Adapted from "RLPReader" by Hamdi Allam (hamdi.allam97@gmail.com). + */ +library RLPReader { + /** + * + * Constants * + * + */ + uint256 internal constant MAX_LIST_LENGTH = 32; + + /** + * + * Enums * + * + */ + enum RLPItemType { + DATA_ITEM, + LIST_ITEM + } + + /** + * + * Structs * + * + */ + struct RLPItem { + uint256 length; + uint256 ptr; + } + + /** + * + * Internal Functions * + * + */ + + /** + * Converts bytes to a reference to memory position and length. + * @param _in Input bytes to convert. + * @return Output memory reference. + */ + function toRLPItem(bytes memory _in) internal pure returns (RLPItem memory) { + uint256 ptr; + assembly { + ptr := add(_in, 32) + } + + return RLPItem({length: _in.length, ptr: ptr}); + } + + /** + * Reads an RLP list value into a list of RLP items. + * @param _in RLP list value. + * @return Decoded RLP list items. + */ + function readList(RLPItem memory _in) internal pure returns (RLPItem[] memory) { + (uint256 listOffset,, RLPItemType itemType) = _decodeLength(_in); + + require(itemType == RLPItemType.LIST_ITEM, "Invalid RLP list value."); + + // Solidity in-memory arrays can't be increased in size, but *can* be decreased in size by + // writing to the length. Since we can't know the number of RLP items without looping over + // the entire input, we'd have to loop twice to accurately size this array. It's easier to + // simply set a reasonable maximum list length and decrease the size before we finish. + RLPItem[] memory out = new RLPItem[](MAX_LIST_LENGTH); + + uint256 itemCount = 0; + uint256 offset = listOffset; + while (offset < _in.length) { + require(itemCount < MAX_LIST_LENGTH, "Provided RLP list exceeds max list length."); + + (uint256 itemOffset, uint256 itemLength,) = + _decodeLength(RLPItem({length: _in.length - offset, ptr: _in.ptr + offset})); + + out[itemCount] = RLPItem({length: itemLength + itemOffset, ptr: _in.ptr + offset}); + + itemCount += 1; + offset += itemOffset + itemLength; + } + + // Decrease the array size to match the actual item count. + assembly { + mstore(out, itemCount) + } + + return out; + } + + /** + * Reads an RLP list value into a list of RLP items. + * @param _in RLP list value. + * @return Decoded RLP list items. + */ + function readList(bytes memory _in) internal pure returns (RLPItem[] memory) { + return readList(toRLPItem(_in)); + } + + /** + * Reads an RLP bytes value into bytes. + * @param _in RLP bytes value. + * @return Decoded bytes. + */ + function readBytes(RLPItem memory _in) internal pure returns (bytes memory) { + (uint256 itemOffset, uint256 itemLength, RLPItemType itemType) = _decodeLength(_in); + + require(itemType == RLPItemType.DATA_ITEM, "Invalid RLP bytes value."); + + return _copy(_in.ptr, itemOffset, itemLength); + } + + /** + * Reads an RLP bytes value into bytes. + * @param _in RLP bytes value. + * @return Decoded bytes. + */ + function readBytes(bytes memory _in) internal pure returns (bytes memory) { + return readBytes(toRLPItem(_in)); + } + + /** + * Reads an RLP string value into a string. + * @param _in RLP string value. + * @return Decoded string. + */ + function readString(RLPItem memory _in) internal pure returns (string memory) { + return string(readBytes(_in)); + } + + /** + * Reads an RLP string value into a string. + * @param _in RLP string value. + * @return Decoded string. + */ + function readString(bytes memory _in) internal pure returns (string memory) { + return readString(toRLPItem(_in)); + } + + /** + * Reads an RLP bytes32 value into a bytes32. + * @param _in RLP bytes32 value. + * @return Decoded bytes32. + */ + function readBytes32(RLPItem memory _in) internal pure returns (bytes32) { + require(_in.length <= 33, "Invalid RLP bytes32 value."); + + (uint256 itemOffset, uint256 itemLength, RLPItemType itemType) = _decodeLength(_in); + + require(itemType == RLPItemType.DATA_ITEM, "Invalid RLP bytes32 value."); + + uint256 ptr = _in.ptr + itemOffset; + bytes32 out; + assembly { + out := mload(ptr) + + // Shift the bytes over to match the item size. + if lt(itemLength, 32) { out := div(out, exp(256, sub(32, itemLength))) } + } + + return out; + } + + /** + * Reads an RLP bytes32 value into a bytes32. + * @param _in RLP bytes32 value. + * @return Decoded bytes32. + */ + function readBytes32(bytes memory _in) internal pure returns (bytes32) { + return readBytes32(toRLPItem(_in)); + } + + /** + * Reads an RLP uint256 value into a uint256. + * @param _in RLP uint256 value. + * @return Decoded uint256. + */ + function readUint256(RLPItem memory _in) internal pure returns (uint256) { + return uint256(readBytes32(_in)); + } + + /** + * Reads an RLP uint256 value into a uint256. + * @param _in RLP uint256 value. + * @return Decoded uint256. + */ + function readUint256(bytes memory _in) internal pure returns (uint256) { + return readUint256(toRLPItem(_in)); + } + + /** + * Reads an RLP bool value into a bool. + * @param _in RLP bool value. + * @return Decoded bool. + */ + function readBool(RLPItem memory _in) internal pure returns (bool) { + require(_in.length == 1, "Invalid RLP boolean value."); + + uint256 ptr = _in.ptr; + uint256 out; + assembly { + out := byte(0, mload(ptr)) + } + + require(out == 0 || out == 1, "RLPReader: Invalid RLP boolean value, must be 0 or 1"); + + return out != 0; + } + + /** + * Reads an RLP bool value into a bool. + * @param _in RLP bool value. + * @return Decoded bool. + */ + function readBool(bytes memory _in) internal pure returns (bool) { + return readBool(toRLPItem(_in)); + } + + /** + * Reads an RLP address value into a address. + * @param _in RLP address value. + * @return Decoded address. + */ + function readAddress(RLPItem memory _in) internal pure returns (address) { + if (_in.length == 1) { + return address(0); + } + + require(_in.length == 21, "Invalid RLP address value."); + + return address(uint160(readUint256(_in))); + } + + /** + * Reads an RLP address value into a address. + * @param _in RLP address value. + * @return Decoded address. + */ + function readAddress(bytes memory _in) internal pure returns (address) { + return readAddress(toRLPItem(_in)); + } + + /** + * Reads the raw bytes of an RLP item. + * @param _in RLP item to read. + * @return Raw RLP bytes. + */ + function readRawBytes(RLPItem memory _in) internal pure returns (bytes memory) { + return _copy(_in); + } + + /** + * + * Private Functions * + * + */ + + /** + * Decodes the length of an RLP item. + * @param _in RLP item to decode. + * @return Offset of the encoded data. + * @return Length of the encoded data. + * @return RLP item type (LIST_ITEM or DATA_ITEM). + */ + function _decodeLength(RLPItem memory _in) private pure returns (uint256, uint256, RLPItemType) { + unchecked { + require(_in.length > 0, "RLP item cannot be null."); + + uint256 ptr = _in.ptr; + uint256 prefix; + assembly { + prefix := byte(0, mload(ptr)) + } + + if (prefix <= 0x7f) { + // Single byte. + + return (0, 1, RLPItemType.DATA_ITEM); + } else if (prefix <= 0xb7) { + // Short string. + + uint256 strLen = prefix - 0x80; + + require(_in.length > strLen, "Invalid RLP short string."); + + return (1, strLen, RLPItemType.DATA_ITEM); + } else if (prefix <= 0xbf) { + // Long string. + uint256 lenOfStrLen = prefix - 0xb7; + + require(_in.length > lenOfStrLen, "Invalid RLP long string length."); + + uint256 strLen; + assembly { + // Pick out the string length. + strLen := div(mload(add(ptr, 1)), exp(256, sub(32, lenOfStrLen))) + } + + require(_in.length > lenOfStrLen + strLen, "Invalid RLP long string."); + + return (1 + lenOfStrLen, strLen, RLPItemType.DATA_ITEM); + } else if (prefix <= 0xf7) { + // Short list. + uint256 listLen = prefix - 0xc0; + + require(_in.length > listLen, "Invalid RLP short list."); + + return (1, listLen, RLPItemType.LIST_ITEM); + } else { + // Long list. + uint256 lenOfListLen = prefix - 0xf7; + + require(_in.length > lenOfListLen, "Invalid RLP long list length."); + + uint256 listLen; + assembly { + // Pick out the list length. + listLen := div(mload(add(ptr, 1)), exp(256, sub(32, lenOfListLen))) + } + + require(_in.length > lenOfListLen + listLen, "Invalid RLP long list."); + + return (1 + lenOfListLen, listLen, RLPItemType.LIST_ITEM); + } + } + } + + /** + * Copies the bytes from a memory location. + * @param _src Pointer to the location to read from. + * @param _offset Offset to start reading from. + * @param _length Number of bytes to read. + * @return Copied bytes. + */ + function _copy(uint256 _src, uint256 _offset, uint256 _length) private pure returns (bytes memory) { + unchecked { + bytes memory out = new bytes(_length); + if (out.length == 0) { + return out; + } + + uint256 src = _src + _offset; + uint256 dest; + assembly { + dest := add(out, 32) + } + + // Copy over as many complete words as we can. + for (uint256 i = 0; i < _length / 32; i++) { + assembly { + mstore(dest, mload(src)) + } + + src += 32; + dest += 32; + } + + // Pick out the remaining bytes. + uint256 mask = 256 ** (32 - (_length % 32)) - 1; + assembly { + mstore(dest, or(and(mload(src), not(mask)), and(mload(dest), mask))) + } + + return out; + } + } + + /** + * Copies an RLP item into bytes. + * @param _in RLP item to copy. + * @return Copied bytes. + */ + function _copy(RLPItem memory _in) private pure returns (bytes memory) { + return _copy(_in.ptr, 0, _in.length); + } +} diff --git a/examples/eth/src/lib/RLPWriter.sol b/examples/eth/src/lib/RLPWriter.sol new file mode 100644 index 0000000..107683c --- /dev/null +++ b/examples/eth/src/lib/RLPWriter.sol @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.25; +pragma experimental ABIEncoderV2; + +/** + * @title RLPWriter + * @author Bakaoh (with modifications) + */ +library RLPWriter { + /** + * + * Internal Functions * + * + */ + + /** + * RLP encodes a byte string. + * @param _in The byte string to encode. + * @return _out The RLP encoded string in bytes. + */ + function writeBytes(bytes memory _in) internal pure returns (bytes memory _out) { + bytes memory encoded; + + if (_in.length == 1 && uint8(_in[0]) < 128) { + encoded = _in; + } else { + encoded = abi.encodePacked(_writeLength(_in.length, 128), _in); + } + + return encoded; + } + + /** + * RLP encodes a list of RLP encoded byte byte strings. + * @param _in The list of RLP encoded byte strings. + * @return _out The RLP encoded list of items in bytes. + */ + function writeList(bytes[] memory _in) internal pure returns (bytes memory _out) { + bytes memory list = _flatten(_in); + return abi.encodePacked(_writeLength(list.length, 192), list); + } + + /** + * RLP encodes a string. + * @param _in The string to encode. + * @return _out The RLP encoded string in bytes. + */ + function writeString(string memory _in) internal pure returns (bytes memory _out) { + return writeBytes(bytes(_in)); + } + + /** + * RLP encodes an address. + * @param _in The address to encode. + * @return _out The RLP encoded address in bytes. + */ + function writeAddress(address _in) internal pure returns (bytes memory _out) { + return writeBytes(abi.encodePacked(_in)); + } + + /** + * RLP encodes a uint. + * @param _in The uint256 to encode. + * @return _out The RLP encoded uint256 in bytes. + */ + function writeUint(uint256 _in) internal pure returns (bytes memory _out) { + return writeBytes(_toBinary(_in)); + } + + /** + * RLP encodes a bool. + * @param _in The bool to encode. + * @return _out The RLP encoded bool in bytes. + */ + function writeBool(bool _in) internal pure returns (bytes memory _out) { + bytes memory encoded = new bytes(1); + encoded[0] = (_in ? bytes1(0x01) : bytes1(0x80)); + return encoded; + } + + /** + * + * Private Functions * + * + */ + + /** + * Encode the first byte, followed by the `len` in binary form if `length` is more than 55. + * @param _len The length of the string or the payload. + * @param _offset 128 if item is string, 192 if item is list. + * @return _encoded RLP encoded bytes. + */ + function _writeLength(uint256 _len, uint256 _offset) private pure returns (bytes memory _encoded) { + bytes memory encoded; + + if (_len < 56) { + encoded = new bytes(1); + encoded[0] = bytes1(uint8(_len) + uint8(_offset)); + } else { + uint256 lenLen; + uint256 i = 1; + while (_len / i != 0) { + lenLen++; + i *= 256; + } + + encoded = new bytes(lenLen + 1); + encoded[0] = bytes1(uint8(lenLen) + uint8(_offset) + 55); + for (i = 1; i <= lenLen; i++) { + encoded[i] = bytes1(uint8((_len / (256 ** (lenLen - i))) % 256)); + } + } + + return encoded; + } + + /** + * Encode integer in big endian binary form with no leading zeroes. + * @notice TODO: This should be optimized with assembly to save gas costs. + * @param _x The integer to encode. + * @return _binary RLP encoded bytes. + */ + function _toBinary(uint256 _x) private pure returns (bytes memory _binary) { + bytes memory b = abi.encodePacked(_x); + + uint256 i = 0; + for (; i < 32; i++) { + if (b[i] != 0) { + break; + } + } + + bytes memory res = new bytes(32 - i); + for (uint256 j = 0; j < res.length; j++) { + res[j] = b[i++]; + } + + return res; + } + + /** + * Copies a piece of memory to another location. + * @notice From: https://github.com/Arachnid/solidity-stringutils/blob/master/src/strings.sol. + * @param _dest Destination location. + * @param _src Source location. + * @param _len Length of memory to copy. + */ + function _memcpy(uint256 _dest, uint256 _src, uint256 _len) private pure { + uint256 dest = _dest; + uint256 src = _src; + uint256 len = _len; + + for (; len >= 32; len -= 32) { + assembly { + mstore(dest, mload(src)) + } + dest += 32; + src += 32; + } + + uint256 mask = 256 ** (32 - len) - 1; + assembly { + let srcpart := and(mload(src), not(mask)) + let destpart := and(mload(dest), mask) + mstore(dest, or(destpart, srcpart)) + } + } + + /** + * Flattens a list of byte strings into one byte string. + * @notice From: https://github.com/sammayo/solidity-rlp-encoder/blob/master/RLPEncode.sol. + * @param _list List of byte strings to flatten. + * @return _flattened The flattened byte string. + */ + function _flatten(bytes[] memory _list) private pure returns (bytes memory _flattened) { + if (_list.length == 0) { + return new bytes(0); + } + + uint256 len; + uint256 i = 0; + for (; i < _list.length; i++) { + len += _list[i].length; + } + + bytes memory flattened = new bytes(len); + uint256 flattenedPtr; + assembly { + flattenedPtr := add(flattened, 0x20) + } + + for (i = 0; i < _list.length; i++) { + bytes memory item = _list[i]; + + uint256 listPtr; + assembly { + listPtr := add(item, 0x20) + } + + _memcpy(flattenedPtr, listPtr, item.length); + flattenedPtr += _list[i].length; + } + + return flattened; + } +} diff --git a/examples/eth/test/Prover.t.sol b/examples/eth/test/Prover.t.sol new file mode 100644 index 0000000..c42c62e --- /dev/null +++ b/examples/eth/test/Prover.t.sol @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.25; + +import {Test, console} from "forge-std/Test.sol"; + +import {Prover} from "../src/Prover.sol"; +import {RLPWriter} from "../src/lib/RLPWriter.sol"; + +contract ProverTest is Test { + Prover prover; + + function setUp() public { + prover = new Prover(); + } + + function test_e2eTxInclusionProof() public { + // Provide a transaction hash to be verified: + bytes32 txHash = 0x011fdfcd44319747eb06301a4cde66d9a03b69fefe8fd736fedbf1b3caa21d93; + + // Call the Rust CLI tool to generate the proof: + string[] memory cmd = new string[](7); + cmd[0] = "cargo"; + cmd[1] = "run"; + cmd[2] = "--bin"; + cmd[3] = "cli"; + cmd[4] = "tx"; + cmd[5] = vm.toString(txHash); + cmd[6] = vm.envOr("RPC_URL", string("https://cloudflare-eth.com")); + + string memory res = string(vm.ffi(cmd)); + + // Parse the proof response from JSON: + bytes[] memory parsedProof = vm.parseJsonBytesArray(res, ".proof"); + uint256 index = vm.parseJsonUint(res, ".index"); + bytes32 root = vm.parseJsonBytes32(res, ".root"); + + // Encode the proof data and the index correctly: + bytes memory key = RLPWriter.writeUint(index); + bytes memory proofData = _RLPEncodeList(parsedProof); + + // Verify the proof by checking the presence in the trie: + (bool exists, bytes memory txRLP) = prover.get(key, proofData, root); + + assertEq(exists, true); + assertEq(keccak256(txRLP), txHash); + } + + // Helper to encode a list of bytes items into RLP with each item RLP-encoded as well + function _RLPEncodeList(bytes[] memory _items) internal pure returns (bytes memory) { + bytes[] memory encodedItems = new bytes[](_items.length); + for (uint256 i = 0; i < _items.length; i++) { + encodedItems[i] = RLPWriter.writeBytes(_items[i]); + } + return RLPWriter.writeList(encodedItems); + } +}