Skip to content

Commit

Permalink
Merkle tree construction (#504)
Browse files Browse the repository at this point in the history
* Building a merkle tree

* Obtaining merkle proof from a tree

---------

Co-authored-by: benbierens <[email protected]>
  • Loading branch information
tbekas and benbierens authored Aug 15, 2023
1 parent 39efac1 commit e860127
Show file tree
Hide file tree
Showing 5 changed files with 302 additions and 1 deletion.
189 changes: 189 additions & 0 deletions codex/merkletree/merkletree.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
## Nim-Codex
## Copyright (c) 2022 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.

import std/sequtils
import std/math
import std/bitops
import std/sugar

import pkg/libp2p
import pkg/stew/byteutils
import pkg/questionable
import pkg/questionable/results

type
  # Hash value stored in every node of the tree (alias of `MultiHash`).
  MerkleHash* = MultiHash
  # Merkle tree kept as one flat sequence of node hashes.
  MerkleTree* = object
    # number of leaves, i.e. the width of level 0
    leavesCount: int
    # all node hashes: leaves first, then each higher level, root last
    nodes: seq[MerkleHash]
  # Proof of inclusion for a single leaf.
  MerkleProof* = object
    # index of the leaf the proof was requested for
    index: int
    # sibling hashes, one per level, starting at the leaf level
    path: seq[MerkleHash]

# Tree constructed from leaves H0..H2 is
#            H5=H(H3 & H4)
#           /            \
#    H3=H(H0 & H1)     H4=H(H2 & H2)
#   /    \            /
# H0=H(A) H1=H(B)     H2=H(C)
# |       |           |
# A       B           C
#
# Memory layout is [H0, H1, H2, H3, H4, H5]
#
# Proofs of inclusion are
# - [H1, H4] for A
# - [H0, H4] for B
# - [H2, H3] for C


func computeTreeHeight(leavesCount: int): int =
  ## Returns the number of levels (leaf level included) of a tree built
  ## over `leavesCount` leaves.
  ##
  ## A power-of-two leaf count gives `log2(n) + 1` levels; any other
  ## positive count is rounded up to the next power of two first, which
  ## adds one more level.
  ##
  ## Returns 0 when `leavesCount` is not positive: `fastLog2` is undefined
  ## for 0 and meaningless for negative values, so those inputs are
  ## treated as "no tree" instead of yielding an arbitrary height.
  if leavesCount <= 0:
    0
  elif isPowerOfTwo(leavesCount):
    fastLog2(leavesCount) + 1
  else:
    fastLog2(leavesCount) + 2

func getLowHigh(leavesCount, level: int): (int, int) =
  ## Returns the inclusive `(first, last)` positions, within the flat node
  ## sequence, of the nodes that make up `level` (level 0 is the leaves).
  var
    offset = 0
    span = leavesCount
    remaining = level
  # Skip over every level below the requested one; each level up holds
  # half as many nodes as the one beneath it, rounded up.
  while remaining > 0:
    offset += span
    span = (span + 1) div 2
    dec remaining
  result = (offset, offset + span - 1)

func getLowHigh(self: MerkleTree, level: int): (int, int) =
  ## Inclusive `(first, last)` node positions of `level` for this tree,
  ## derived from the tree's own leaf count.
  result = self.leavesCount.getLowHigh(level)

func computeTotalSize(leavesCount: int): int =
  ## Total number of nodes (leaves, intermediate nodes and root) in a tree
  ## built over `leavesCount` leaves.
  let topLevel = computeTreeHeight(leavesCount) - 1
  let topBounds = getLowHigh(leavesCount, topLevel)
  # the last position of the top level is the last position overall
  topBounds[1] + 1

proc getWidth(self: MerkleTree, level: int): int =
  ## Number of nodes on `level` of this tree.
  let bounds = self.getLowHigh(level)
  bounds[1] - bounds[0] + 1

func getChildren(self: MerkleTree, level, i: int): (MerkleHash, MerkleHash) =
  ## Returns the `(left, right)` child hashes of node `i` on `level`.
  ##
  ## The children live one level below. When the left child is the last
  ## node of that level, it is returned as the right child as well.
  let childBounds = self.getLowHigh(level - 1)
  let left = childBounds[0] + 2 * i
  # clamp so an unpaired last node is combined with itself
  let right = min(left + 1, childBounds[1])
  result = (self.nodes[left], self.nodes[right])

func getSibling(self: MerkleTree, level, i: int): MerkleHash =
  ## Returns the hash of the node paired with node `i` on `level`.
  ##
  ## Nodes at an even position pair with their right neighbour, nodes at
  ## an odd position with their left neighbour. An even-positioned node
  ## that is the last one on its level is its own sibling.
  let bounds = self.getLowHigh(level)
  let pos = bounds[0] + i
  let siblingPos =
    if i mod 2 != 0:
      pos - 1
    else:
      min(pos + 1, bounds[1])
  self.nodes[siblingPos]

proc setNode(self: var MerkleTree, level, i: int, value: MerkleHash): void =
  ## Stores `value` as node `i` of `level`.
  let levelStart = self.getLowHigh(level)[0]
  self.nodes[levelStart + i] = value

proc root*(self: MerkleTree): MerkleHash =
  ## The root hash, i.e. the last entry of the flat node sequence.
  result = self.nodes[self.nodes.high]

proc len*(self: MerkleTree): int =
  ## Total number of nodes held by the tree (leaves included).
  result = len(self.nodes)

proc leaves*(self: MerkleTree): seq[MerkleHash] =
  ## A copy of the leaf hashes: the first `leavesCount` stored nodes.
  result = self.nodes[0 .. self.leavesCount - 1]

proc nodes*(self: MerkleTree): seq[MerkleHash] =
  ## All node hashes in storage order (leaves first, root last).
  result = self.nodes

proc height*(self: MerkleTree): int =
  ## Number of levels in the tree, leaf level included.
  result = self.leavesCount.computeTreeHeight()

proc `$`*(self: MerkleTree): string =
  ## Two-line textual form: the leaf count followed by every node hash.
  let countLine = "leavesCount: " & $self.leavesCount
  let nodesLine = "\nnodes: " & $self.nodes
  countLine & nodesLine

proc getProof*(self: MerkleTree, index: int): ?!MerkleProof =
  ## Builds the proof of inclusion for the leaf at `index`.
  ##
  ## The proof path holds one sibling hash per level below the root,
  ## ordered from the leaf level upwards.
  ##
  ## Returns a failure when `index` is outside `0 ..< leavesCount`.
  ##
  ## The path is read straight from the stored nodes, so a tree whose
  ## nodes were never computed (one created from a root hash only) yields
  ## default-initialized hashes in the path.
  if index < 0 or index >= self.leavesCount:
    # `leavesCount - 1` is the highest valid index; computing it directly
    # avoids copying the whole leaf slice just to format the message.
    return failure("Index " & $index & " out of range [0.." & $(self.leavesCount - 1) & "]")

  var path = newSeq[MerkleHash](self.height - 1)
  for level in 0..<path.len:
    # position, on this level, of the ancestor of the requested leaf
    let i = index div (1 shl level)
    path[level] = self.getSibling(level, i)

  success(MerkleProof(index: index, path: path))

proc initTreeFromLeaves(leaves: openArray[MerkleHash]): ?!MerkleTree =
  ## Builds a complete tree from `leaves`: the leaves are copied into
  ## level 0 and every higher level is filled by hashing pairs of nodes
  ## from the level below, up to the root.
  ##
  ## Returns a failure when `leaves` is empty, when the leaves do not all
  ## use the same codec, or when hashing a pair of nodes fails.

  # the first leaf supplies the codec and digest size used for the tree
  without mcodec =? leaves.?[0].?mcodec and
          digestSize =? leaves.?[0].?size:
    return failure("At least one leaf is required")

  if not leaves.allIt(it.mcodec == mcodec):
    return failure("All leaves must use the same codec")

  let totalSize = computeTotalSize(leaves.len)
  var tree = MerkleTree(leavesCount: leaves.len, nodes: newSeq[MerkleHash](totalSize))

  # scratch buffer for one pair of hashes; reused by every `combine` call
  var buf = newSeq[byte](digestSize * 2)
  proc combine(l, r: MerkleHash): ?!MerkleHash =
    # Copies the first `digestSize` bytes of each hash's `data.buffer`.
    # NOTE(review): the copy starts at byte 0 of the buffer - confirm this
    # is the intended byte range for a `MultiHash`.
    # NOTE(review): only the codec is validated above, not the size, so
    # this assumes every buffer holds at least `digestSize` bytes.
    copyMem(addr buf[0], unsafeAddr l.data.buffer[0], digestSize)
    copyMem(addr buf[digestSize], unsafeAddr r.data.buffer[0], digestSize)

    MultiHash.digest($mcodec, buf).mapErr(
      c => newException(CatchableError, "Error calculating hash using codec " & $mcodec & ": " & $c)
    )

  # copy leaves
  for i in 0..<tree.getWidth(0):
    tree.setNode(0, i, leaves[i])

  # calculate intermediate nodes
  for level in 1..<tree.height:
    for i in 0..<tree.getWidth(level):
      # an unpaired last node of the lower level is returned as both children
      let (left, right) = tree.getChildren(level, i)

      without mhash =? combine(left, right), error:
        return failure(error)
      tree.setNode(level, i, mhash)

  success(tree)

func init*(
  T: type MerkleTree,
  root: MerkleHash,
  leavesCount: int
): MerkleTree =
  ## Creates a tree sized for `leavesCount` leaves in which only the root
  ## is set; every other node keeps its default-initialized value.
  var storage = newSeq[MerkleHash](computeTotalSize(leavesCount))
  # the root occupies the last slot of the flat node sequence
  storage[^1] = root
  result = MerkleTree(leavesCount: leavesCount, nodes: storage)

proc init*(
  T: type MerkleTree,
  leaves: openArray[MerkleHash]
): ?!MerkleTree =
  ## Builds a full tree from `leaves`; see `initTreeFromLeaves` for the
  ## failure conditions.
  result = initTreeFromLeaves(leaves)

proc index*(self: MerkleProof): int =
  ## Index of the leaf this proof refers to.
  result = self.index

proc path*(self: MerkleProof): seq[MerkleHash] =
  ## The hashes stored in this proof, in the order they were supplied.
  result = self.path

proc `$`*(self: MerkleProof): string =
  ## Two-line textual form: the leaf index followed by the proof path.
  let indexLine = "index: " & $self.index
  let pathLine = "\npath: " & $self.path
  indexLine & pathLine

func `==`*(a, b: MerkleProof): bool =
  ## Two proofs are equal when both their leaf index and their path match.
  if a.index != b.index:
    return false
  a.path == b.path

proc init*(
  T: type MerkleProof,
  index: int,
  path: seq[MerkleHash]
): MerkleProof =
  ## Creates a proof for the leaf at `index` with the given `path`.
  result = MerkleProof(index: index, path: path)
108 changes: 108 additions & 0 deletions tests/codex/merkletree/testmerkletree.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import std/unittest
import std/bitops
import std/random
import std/sequtils
import pkg/libp2p
import codex/merkletree/merkletree
import ../helpers
import pkg/questionable/results

# Tests for building a MerkleTree and obtaining inclusion proofs from it.
checksuite "merkletree":
  const sha256 = multiCodec("sha2-256")
  const sha512 = multiCodec("sha2-512")

  # Hashes 32 random bytes with `codec` to produce an arbitrary leaf hash.
  proc randomHash(codec: MultiCodec = sha256): MerkleHash =
    var data: array[0..31, byte]
    for i in 0..31:
      data[i] = rand(uint8)
    return MultiHash.digest($codec, data).tryGet()

  # Reference computation of a parent node: hashes the first `size` bytes
  # of each operand's `data.buffer`, concatenated.
  proc combine(a, b: MerkleHash, codec: MultiCodec = sha256): MerkleHash =
    var buf = newSeq[byte](a.size + b.size)
    for i in 0..<a.size:
      buf[i] = a.data.buffer[i]
    for i in 0..<b.size:
      buf[i + a.size] = b.data.buffer[i]
    return MultiHash.digest($codec, buf).tryGet()

  var
    leaves: array[0..10, MerkleHash]

  # fresh random leaves before every test
  setup:
    for i in 0..leaves.high:
      leaves[i] = randomHash()

  test "tree with one leaf has expected root":
    let tree = MerkleTree.init(leaves[0..0]).tryGet()

    # a single leaf is its own root
    check:
      tree.leaves == leaves[0..0]
      tree.root == leaves[0]
      tree.len == 1

  test "tree with two leaves has expected root":
    let
      expectedRoot = combine(leaves[0], leaves[1])

    let tree = MerkleTree.init(leaves[0..1]).tryGet()

    check:
      tree.leaves == leaves[0..1]
      tree.len == 3
      tree.root == expectedRoot

  test "tree with three leaves has expected root":
    # the unpaired third leaf is combined with itself
    let
      expectedRoot = combine(combine(leaves[0], leaves[1]), combine(leaves[2], leaves[2]))

    let tree = MerkleTree.init(leaves[0..2]).tryGet()

    check:
      tree.leaves == leaves[0..2]
      tree.len == 6
      tree.root == expectedRoot

  test "tree with two leaves provides expected proofs":
    let tree = MerkleTree.init(leaves[0..1]).tryGet()

    # each leaf's proof is just the other leaf
    let expectedProofs = [
      MerkleProof.init(0, @[leaves[1]]),
      MerkleProof.init(1, @[leaves[0]]),
    ]

    check:
      tree.getProof(0).tryGet() == expectedProofs[0]
      tree.getProof(1).tryGet() == expectedProofs[1]

  test "tree with three leaves provides expected proofs":
    let tree = MerkleTree.init(leaves[0..2]).tryGet()

    # the third leaf has no neighbour, so its first proof element is itself
    let expectedProofs = [
      MerkleProof.init(0, @[leaves[1], combine(leaves[2], leaves[2])]),
      MerkleProof.init(1, @[leaves[0], combine(leaves[2], leaves[2])]),
      MerkleProof.init(2, @[leaves[2], combine(leaves[0], leaves[1])]),
    ]

    check:
      tree.getProof(0).tryGet() == expectedProofs[0]
      tree.getProof(1).tryGet() == expectedProofs[1]
      tree.getProof(2).tryGet() == expectedProofs[2]

  test "getProof fails for index out of bounds":
    let tree = MerkleTree.init(leaves[0..3]).tryGet()

    # valid indices for four leaves are 0..3
    check:
      isErr(tree.getProof(-1))
      isErr(tree.getProof(4))

  test "can create MerkleTree directly from root hash":
    let tree = MerkleTree.init(leaves[0], 1)

    check:
      tree.root == leaves[0]

  test "cannot create MerkleTree from leaves with different codec":
    let res = MerkleTree.init(@[randomHash(sha256), randomHash(sha512)])

    check:
      isErr(res)
3 changes: 3 additions & 0 deletions tests/codex/testmerkletree.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import ./merkletree/testmerkletree

{.warning[UnusedImport]: off.}
1 change: 1 addition & 0 deletions tests/testCodex.nim
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ import ./codex/testclock
import ./codex/testsystemclock
import ./codex/testvalidation
import ./codex/testasyncstreamwrapper
import ./codex/testmerkletree

{.warning[UnusedImport]: off.}
2 changes: 1 addition & 1 deletion vendor/questionable

0 comments on commit e860127

Please sign in to comment.