From 891ee51431641021f7cc961833e5b8ac9e7060ba Mon Sep 17 00:00:00 2001 From: Scotty <66335769+ScottyPoi@users.noreply.github.com> Date: Mon, 11 Mar 2024 07:26:59 -0600 Subject: [PATCH] Trie: add partialPath parameter to trie.findPath() (#3305) * trie: add optional "partialPath" parameter to findPath * trie: use partialPath input in findPath stack * trie: start findPath walk from end of partialPath * trie: identify starting point in debug log * trie: test findPath with partial * trie: test findPath on secure trie --------- Co-authored-by: Holger Drewes --- packages/trie/src/trie.ts | 27 +++++++-- packages/trie/test/trie/findPath.spec.ts | 70 ++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 4 deletions(-) create mode 100644 packages/trie/test/trie/findPath.spec.ts diff --git a/packages/trie/src/trie.ts b/packages/trie/src/trie.ts index a47b458bc8..ea2d9a14d3 100644 --- a/packages/trie/src/trie.ts +++ b/packages/trie/src/trie.ts @@ -596,11 +596,23 @@ export class Trie { * @param key - the search key * @param throwIfMissing - if true, throws if any nodes are missing. Used for verifying proofs. (default: false) */ - async findPath(key: Uint8Array, throwIfMissing = false): Promise { + async findPath( + key: Uint8Array, + throwIfMissing = false, + partialPath: { + stack: TrieNode[] + } = { + stack: [], + } + ): Promise { const targetKey = bytesToNibbles(key) const keyLen = targetKey.length const stack: TrieNode[] = Array.from({ length: keyLen }) let progress = 0 + for (let i = 0; i < partialPath.stack.length - 1; i++) { + stack[i] = partialPath.stack[i] + progress += stack[i] instanceof BranchNode ? 1 : (stack[i]).keyLength() + } this.DEBUG && this.debug(`Target (${targetKey.length}): [${targetKey}]`, ['FIND_PATH']) let result: Path | null = null @@ -672,10 +684,17 @@ export class Trie { walkController.allChildren(node, keyProgress) } } - + const startingNode = partialPath.stack[partialPath.stack.length - 1] + const start = startingNode !== undefined ? this.hash(startingNode?.serialize()) : this.root() try { - this.DEBUG && this.debug(`Walking trie from root: ${bytesToHex(this.root())}`, ['FIND_PATH']) - await this.walkTrie(this.root(), onFound) + this.DEBUG && + this.debug( + `Walking trie from ${startingNode === undefined ? 'ROOT' : 'NODE'}: ${bytesToHex( + start as Uint8Array + )}`, + ['FIND_PATH'] + ) + await this.walkTrie(start, onFound) } catch (error: any) { if (error.message !== 'Missing node in DB' || throwIfMissing) { throw error diff --git a/packages/trie/test/trie/findPath.spec.ts b/packages/trie/test/trie/findPath.spec.ts new file mode 100644 index 0000000000..4bcf859fe6 --- /dev/null +++ b/packages/trie/test/trie/findPath.spec.ts @@ -0,0 +1,70 @@ +import { randomBytes } from '@ethereumjs/util' +import { assert, describe, it } from 'vitest' + +import { Trie } from '../../src/index.js' + +describe('TRIE > findPath', async () => { + const keys = Array.from({ length: 200 }, () => randomBytes(8)) + const trie = new Trie() + for (const [i, k] of keys.entries()) { + await trie.put(k, Uint8Array.from([i, i])) + } + const rootNode = await trie.lookupNode(trie.root()) + for (const [idx, k] of keys.slice(0, 10).entries()) { + const val = await trie.get(k) + it('should find values for key', async () => { + assert.deepEqual(val, Uint8Array.from([idx, idx])) + }) + trie['debug']('FIND PATH ORIGINAL:' + '-'.repeat(20)) + const path = await trie.findPath(k) + it('should find path for key', async () => { + assert.isNotNull(path.node) + assert.deepEqual(path.stack[0], rootNode) + assert.deepEqual(path.node?.value(), Uint8Array.from([idx, idx])) + }) + trie['debug'](`FINDING PARTIAL PATHS: ` + path.stack.length + '-'.repeat(20)) + for (let i = 1; i <= path.stack.length - 1; i++) { + trie['debug']('FIND PATH PARTIAL: ' + i + '-'.repeat(20)) + const pathFromPartial = await trie.findPath(k, false, { stack: path.stack.slice(0, i) }) + it(`should find path for key from partial stack (${i}/${path.stack.length})`, async () => { + assert.deepEqual(path, pathFromPartial) + assert.isNotNull(pathFromPartial.node) + assert.deepEqual(pathFromPartial.stack[0], rootNode) + assert.deepEqual(pathFromPartial.node?.value(), Uint8Array.from([idx, idx])) + assert.equal(path.stack.length, pathFromPartial.stack.length) + }) + } + } +}) +describe('TRIE (secure) > findPath', async () => { + const keys = Array.from({ length: 1000 }, () => randomBytes(20)) + const trie = new Trie({ useKeyHashing: true }) + for (const [i, k] of keys.entries()) { + await trie.put(k, Uint8Array.from([i, i])) + } + const rootNode = await trie.lookupNode(trie.root()) + for (const [idx, k] of keys.slice(0, 10).entries()) { + const val = await trie.get(k) + it('should find value for key', async () => { + assert.deepEqual(val, Uint8Array.from([idx, idx])) + }) + const path = await trie.findPath(trie['hash'](k)) + it('should find path for key', async () => { + assert.isNotNull(path.node) + assert.deepEqual(path.stack[0], rootNode) + assert.deepEqual(path.node?.value(), Uint8Array.from([idx, idx])) + }) + for (let i = 2; i <= path.stack.length - 1; i++) { + const pathFromPartial = await trie.findPath(trie['hash'](k), false, { + stack: path.stack.slice(0, i), + }) + it(`should find path for key from partial stack (${i}/${path.stack.length})`, async () => { + assert.deepEqual(path, pathFromPartial) + assert.isNotNull(pathFromPartial.node) + assert.deepEqual(pathFromPartial.stack[0], rootNode) + assert.deepEqual(pathFromPartial.node?.value(), Uint8Array.from([idx, idx])) + assert.equal(path.stack.length, pathFromPartial.stack.length) + }) + } + } +})