diff --git a/internal/round/smt_persistence_integration_test.go b/internal/round/smt_persistence_integration_test.go index a608d98..2163d5f 100644 --- a/internal/round/smt_persistence_integration_test.go +++ b/internal/round/smt_persistence_integration_test.go @@ -92,7 +92,8 @@ func TestSmtPersistenceAndRestoration(t *testing.T) { // Verify inclusion proofs work for _, leaf := range testLeaves { - merkleTreePath := restoredRm.smt.GetPath(leaf.Path) + merkleTreePath, err := restoredRm.smt.GetPath(leaf.Path) + require.NoError(t, err) require.NotNil(t, merkleTreePath, "Should be able to get Merkle path") assert.NotEmpty(t, merkleTreePath.Root, "Merkle path should have root hash") } @@ -243,7 +244,8 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { path, err := commitment.RequestID.GetPath() require.NoError(t, err, "Should be able to get path from request ID") - merkleTreePath := newRm.smt.GetPath(path) + merkleTreePath, err := newRm.smt.GetPath(path) + require.NoError(t, err) require.NotNil(t, merkleTreePath, "Should be able to get Merkle path") assert.NotEmpty(t, merkleTreePath.Root, "Merkle path should have root hash") assert.NotEmpty(t, merkleTreePath.Steps, "Merkle path should have steps") diff --git a/internal/service/service.go b/internal/service/service.go index 3da406b..ddc676f 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -196,7 +196,10 @@ func (as *AggregatorService) GetInclusionProof(ctx context.Context, req *api.Get if err != nil { return nil, fmt.Errorf("failed to get path for request ID %s: %w", req.RequestID, err) } - merkleTreePath := as.roundManager.GetSMT().GetPath(path) + merkleTreePath, err := as.roundManager.GetSMT().GetPath(path) + if err != nil { + return nil, fmt.Errorf("failed to get inclusion proof for request ID %s: %w", req.RequestID, err) + } // Find the latest block that matches the current SMT root hash rootHash, err := api.NewHexBytesFromString(merkleTreePath.Root) diff --git a/internal/smt/smt.go b/internal/smt/smt.go index dabc915..368cc36 100644 --- a/internal/smt/smt.go +++ b/internal/smt/smt.go @@ -13,12 +13,15 @@ import ( var ( ErrDuplicateLeaf = errors.New("smt: duplicate leaf") ErrLeafModification = errors.New("smt: attempt to modify an existing leaf") + ErrKeyLength = errors.New("smt: invalid key length") + ErrWrongShard = errors.New("smt: key does not belong in this shard") ) type ( - // SparseMerkleTree implements a sparse merkle tree compatible with Unicity SDK + // SparseMerkleTree implements a sparse Merkle tree compatible with Unicity SDK SparseMerkleTree struct { - keyLength int // bit length of the keys in the tree + parentMode bool // true if this tree operates in "parent mode" + keyLength int // bit length of the keys in the tree algorithm api.HashAlgorithm root *NodeBranch isSnapshot bool // true if this is a snapshot, false if original tree @@ -31,24 +34,77 @@ type ( } ) -// NewSparseMerkleTree creates a new sparse merkle tree +// NewSparseMerkleTree creates a new sparse Merkle tree for a monolithic aggregator func NewSparseMerkleTree(algorithm api.HashAlgorithm, keyLength int) *SparseMerkleTree { if keyLength <= 0 { panic("SMT key length must be positive") } return &SparseMerkleTree{ + parentMode: false, keyLength: keyLength, algorithm: algorithm, - root: newRootNode(nil, nil), + root: newRootBranch(big.NewInt(1), nil, nil), isSnapshot: false, original: nil, } } +// NewChildSparseMerkleTree creates a new sparse Merkle tree for a child aggregator in sharded setup +func NewChildSparseMerkleTree(algorithm api.HashAlgorithm, keyLength int, shardID api.ShardID) *SparseMerkleTree { + if keyLength <= 0 { + panic("SMT key length must be positive") + } + if shardID <= 1 { + panic("Shard ID must be positive and have at least 2 bits") + } + path := big.NewInt(shardID) + if path.BitLen() > keyLength { + panic("Shard ID must be shorter than SMT key length") + } + return &SparseMerkleTree{ + parentMode: false, + keyLength: keyLength, + algorithm: algorithm, + root: newRootBranch(path, nil, nil), + isSnapshot: false, + original: nil, + } +} + +// NewParentSparseMerkleTree creates a new sparse Merkle tree for the parent aggregator in sharded setup +func NewParentSparseMerkleTree(algorithm api.HashAlgorithm, keyLength int) *SparseMerkleTree { + tree := NewSparseMerkleTree(algorithm, keyLength) + tree.parentMode = true + + // Populate all leaves with null hashes + // To allow the child aggregators to compute the correct root hashes + // for their respective leaves in the parent aggregator's tree, the + // parent tree is fully populated (thus not a sparse tree at all) + // It is expected that these nulls will be replaced from the stored + // state at once after the tree is initially constructed, but still + // better to ensure all the leaves exist; otherwise the hash values + // of siblings of the missing nodes would not match the structure of + // the tree and the corresponding inclusion proofs would fail to verify + tree.root.Left = populate(0b10, keyLength) // left child 1 level below + tree.root.Right = populate(0b11, keyLength) // right child 1 level below + + return tree +} + +func populate(path, levels int) branch { + if levels == 1 { + return newChildLeafBranch(big.NewInt(int64(path)), nil) + } + left := populate(0b10, levels-1) // left child 1 level below + right := populate(0b11, levels-1) // right child 1 level below + return newNodeBranch(big.NewInt(int64(path)), left, right) +} + // CreateSnapshot creates a snapshot of the current SMT state // The snapshot shares nodes with the original tree (copy-on-write) func (smt *SparseMerkleTree) CreateSnapshot() *SmtSnapshot { snapshot := &SparseMerkleTree{ + parentMode: smt.parentMode, keyLength: smt.keyLength, algorithm: smt.algorithm, root: smt.root, // Share the root initially @@ -66,11 +122,15 @@ func (snapshot *SmtSnapshot) Commit() { } // AddLeaf adds a single leaf to the snapshot +// In regular and child mode, only new leaves can be added and any attempt +// to overwrite an existing leaf is an error; in parent mode, updates are allowed func (snapshot *SmtSnapshot) AddLeaf(path *big.Int, value []byte) error { return snapshot.SparseMerkleTree.AddLeaf(path, value) } // AddLeaves adds multiple leaves to the snapshot +// In regular and child mode, only new leaves can be added and any attempt +// to overwrite an existing leaf is an error; in parent mode, updates are allowed func (snapshot *SmtSnapshot) AddLeaves(leaves []*Leaf) error { return snapshot.SparseMerkleTree.AddLeaves(leaves) } @@ -93,7 +153,7 @@ func (smt *SparseMerkleTree) CanModify() bool { // copyOnWriteRoot creates a new root if this snapshot is sharing it with the original func (smt *SparseMerkleTree) copyOnWriteRoot() *NodeBranch { if smt.original != nil && smt.root == smt.original.root { - return newRootNode(smt.root.Left, smt.root.Right) + return newRootBranch(smt.root.Path, smt.root.Left, smt.root.Right) } return smt.root } @@ -122,36 +182,58 @@ type branch interface { // LeafBranch represents a leaf node type LeafBranch struct { - Path *big.Int - Value []byte - hash *api.DataHash + Path *big.Int + Value []byte + hash *api.DataHash + isChild bool // true if this is the root hash form a child aggregator } // NodeBranch represents an internal node type NodeBranch struct { - Path *big.Int - Left branch - Right branch - hash *api.DataHash + Path *big.Int + Left branch + Right branch + hash *api.DataHash + isRoot bool // true if this is the root node } -// NewLeafBranch creates a leaf branch +// NewLeafBranch creates a regular leaf branch func newLeafBranch(path *big.Int, value []byte) *LeafBranch { return &LeafBranch{ - Path: new(big.Int).Set(path), - Value: append([]byte(nil), value...), + Path: new(big.Int).Set(path), + Value: append([]byte(nil), value...), + isChild: false, // Hash will be computed on demand } } +// NewChildLeafBranch creates a parent tree leaf containing the root hash of a child tree +func newChildLeafBranch(path *big.Int, value []byte) *LeafBranch { + if value != nil { + value = append([]byte(nil), value...) + } + return &LeafBranch{ + Path: new(big.Int).Set(path), + Value: value, + isChild: true, + // Hash will be set on demand + } +} + func (l *LeafBranch) calculateHash(hasher *api.DataHasher) *api.DataHash { if l.hash != nil { return l.hash } - pathBytes := api.BigintEncode(l.Path) - l.hash = hasher.Reset().AddData(api.CborArray(2)). - AddCborBytes(pathBytes).AddCborBytes(l.Value).GetHash() + if l.isChild { + if l.Value != nil { + l.hash = api.NewDataHash(hasher.GetAlgorithm(), l.Value) + } + } else { + pathBytes := api.BigintEncode(l.Path) + l.hash = hasher.Reset().AddData(api.CborArray(2)). + AddCborBytes(pathBytes).AddCborBytes(l.Value).GetHash() + } return l.hash } @@ -163,17 +245,24 @@ func (l *LeafBranch) isLeaf() bool { return true } -// NewRootNode creates a new root node -func newRootNode(left, right branch) *NodeBranch { - return newNodeBranch(big.NewInt(1), left, right) +// NewNodeBranch creates a regular node branch +func newNodeBranch(path *big.Int, left, right branch) *NodeBranch { + return &NodeBranch{ + Path: new(big.Int).Set(path), + Left: left, + Right: right, + isRoot: false, + // Hash will be computed on demand + } } -// NewNodeBranch creates a node branch -func newNodeBranch(path *big.Int, left, right branch) *NodeBranch { +// NewRootBranch creates a root node branch +func newRootBranch(path *big.Int, left, right branch) *NodeBranch { return &NodeBranch{ - Path: new(big.Int).Set(path), - Left: left, - Right: right, + Path: new(big.Int).Set(path), + Left: left, + Right: right, + isRoot: true, // Hash will be computed on demand } } @@ -196,8 +285,16 @@ func (n *NodeBranch) calculateHash(hasher *api.DataHasher) *api.DataHash { hasher.Reset().AddData(api.CborArray(3)) - pathBytes := api.BigintEncode(n.Path) - hasher.AddCborBytes(pathBytes) + if n.isRoot && n.Path.BitLen() > 1 { + // This is root of a child tree in sharded setup + // The path to add is the last bit of the shard ID + pos := n.Path.BitLen() - 2 + path := big.NewInt(int64(2 + n.Path.Bit(pos))) + hasher.AddCborBytes(api.BigintEncode(path)) + } else { + // In all other cases we just add the actual path + hasher.AddCborBytes(api.BigintEncode(n.Path)) + } if leftHash == nil { hasher.AddCborNull() @@ -226,7 +323,10 @@ func (n *NodeBranch) isLeaf() bool { // AddLeaf adds a single leaf to the tree func (smt *SparseMerkleTree) AddLeaf(path *big.Int, value []byte) error { if path.BitLen()-1 != smt.keyLength { - return fmt.Errorf("invalid key length %d, should be %d", path.BitLen()-1, smt.keyLength) + return ErrKeyLength + } + if calculateCommonPath(path, smt.root.Path).BitLen() != smt.root.Path.BitLen() { + return ErrWrongShard } // Implement copy-on-write for snapshots only @@ -234,8 +334,8 @@ func (smt *SparseMerkleTree) AddLeaf(path *big.Int, value []byte) error { smt.root = smt.copyOnWriteRoot() } - // TypeScript: const isRight = path & 1n; - isRight := path.Bit(0) == 1 + shifted := new(big.Int).Rsh(path, uint(smt.root.Path.BitLen()-1)) + isRight := shifted.Bit(0) == 1 var left, right branch @@ -249,13 +349,13 @@ func (smt *SparseMerkleTree) AddLeaf(path *big.Int, value []byte) error { } else { rightBranch = smt.root.Right } - newRight, err := smt.buildTree(rightBranch, path, value) + newRight, err := smt.buildTree(rightBranch, shifted, value) if err != nil { return err } right = newRight } else { - right = newLeafBranch(path, value) + right = newLeafBranch(shifted, value) } } else { if smt.root.Left != nil { @@ -266,18 +366,18 @@ func (smt *SparseMerkleTree) AddLeaf(path *big.Int, value []byte) error { } else { leftBranch = smt.root.Left } - newLeft, err := smt.buildTree(leftBranch, path, value) + newLeft, err := smt.buildTree(leftBranch, shifted, value) if err != nil { return err } left = newLeft } else { - left = newLeafBranch(path, value) + left = newLeafBranch(shifted, value) } right = smt.root.Right } - smt.root = newRootNode(left, right) + smt.root = newRootBranch(smt.root.Path, left, right) return nil } @@ -334,12 +434,12 @@ func (smt *SparseMerkleTree) findLeafInBranch(branch branch, targetPath *big.Int commonPath := calculateCommonPath(targetPath, b.Path) // Check if targetPath can be in this subtree - if commonPath.path.Cmp(targetPath) == 0 { + if commonPath.Cmp(targetPath) == 0 { return nil, fmt.Errorf("leaf not found") } // Navigate using the same logic as buildTree - shifted := new(big.Int).Rsh(targetPath, commonPath.length) + shifted := new(big.Int).Rsh(targetPath, uint(commonPath.BitLen()-1)) isRight := shifted.Bit(0) == 1 // KEY FIX: Pass the shifted path to match tree construction @@ -361,7 +461,9 @@ func (smt *SparseMerkleTree) buildTree(branch branch, remainingPath *big.Int, va // Special checks for adding a leaf that already exists in the tree if branch.isLeaf() && branch.getPath().Cmp(remainingPath) == 0 { leafBranch := branch.(*LeafBranch) - if bytes.Equal(leafBranch.Value, value) { + if leafBranch.isChild { + return newChildLeafBranch(leafBranch.Path, value), nil + } else if bytes.Equal(leafBranch.Value, value) { return nil, ErrDuplicateLeaf } else { return nil, ErrLeafModification @@ -369,59 +471,59 @@ func (smt *SparseMerkleTree) buildTree(branch branch, remainingPath *big.Int, va } commonPath := calculateCommonPath(remainingPath, branch.getPath()) - shifted := new(big.Int).Rsh(remainingPath, commonPath.length) + shifted := new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)) isRight := shifted.Bit(0) == 1 - if commonPath.path.Cmp(remainingPath) == 0 { - return nil, fmt.Errorf("cannot add leaf inside branch, commonPath: '%s', remainingPath: '%s'", commonPath.path, remainingPath) + if commonPath.Cmp(remainingPath) == 0 { + return nil, fmt.Errorf("cannot add leaf inside branch, commonPath: '%s', remainingPath: '%s'", commonPath, remainingPath) } // If a leaf must be split from the middle if branch.isLeaf() { leafBranch := branch.(*LeafBranch) - if commonPath.path.Cmp(leafBranch.Path) == 0 { + if commonPath.Cmp(leafBranch.Path) == 0 { return nil, fmt.Errorf("cannot extend tree through leaf") } // TypeScript: branch.path >> commonPath.length - oldBranchPath := new(big.Int).Rsh(leafBranch.Path, commonPath.length) + oldBranchPath := new(big.Int).Rsh(leafBranch.Path, uint(commonPath.BitLen()-1)) oldBranch := newLeafBranch(oldBranchPath, leafBranch.Value) // TypeScript: remainingPath >> commonPath.length - newBranchPath := new(big.Int).Rsh(remainingPath, commonPath.length) + newBranchPath := new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)) newBranch := newLeafBranch(newBranchPath, value) if isRight { - return newNodeBranch(commonPath.path, oldBranch, newBranch), nil + return newNodeBranch(commonPath, oldBranch, newBranch), nil } else { - return newNodeBranch(commonPath.path, newBranch, oldBranch), nil + return newNodeBranch(commonPath, newBranch, oldBranch), nil } } // If node branch is split in the middle nodeBranch := branch.(*NodeBranch) - if commonPath.path.Cmp(nodeBranch.Path) < 0 { - newBranchPath := new(big.Int).Rsh(remainingPath, commonPath.length) + if commonPath.Cmp(nodeBranch.Path) < 0 { + newBranchPath := new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)) newBranch := newLeafBranch(newBranchPath, value) - oldBranchPath := new(big.Int).Rsh(nodeBranch.Path, commonPath.length) + oldBranchPath := new(big.Int).Rsh(nodeBranch.Path, uint(commonPath.BitLen()-1)) oldBranch := newNodeBranch(oldBranchPath, nodeBranch.Left, nodeBranch.Right) if isRight { - return newNodeBranch(commonPath.path, oldBranch, newBranch), nil + return newNodeBranch(commonPath, oldBranch, newBranch), nil } else { - return newNodeBranch(commonPath.path, newBranch, oldBranch), nil + return newNodeBranch(commonPath, newBranch, oldBranch), nil } } if isRight { - newRight, err := smt.buildTree(nodeBranch.Right, new(big.Int).Rsh(remainingPath, commonPath.length), value) + newRight, err := smt.buildTree(nodeBranch.Right, new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)), value) if err != nil { return nil, err } return newNodeBranch(nodeBranch.Path, nodeBranch.Left, newRight), nil } else { - newLeft, err := smt.buildTree(nodeBranch.Left, new(big.Int).Rsh(remainingPath, commonPath.length), value) + newLeft, err := smt.buildTree(nodeBranch.Left, new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)), value) if err != nil { return nil, err } @@ -429,11 +531,12 @@ func (smt *SparseMerkleTree) buildTree(branch branch, remainingPath *big.Int, va } } -func (smt *SparseMerkleTree) GetPath(path *big.Int) *api.MerkleTreePath { +func (smt *SparseMerkleTree) GetPath(path *big.Int) (*api.MerkleTreePath, error) { if path.BitLen()-1 != smt.keyLength { - // TODO: better error handling - fmt.Printf("SparseMerkleTree.GetPath(): invalid key length %d, should be %d", path.BitLen()-1, smt.keyLength) - return nil + return nil, ErrKeyLength + } + if calculateCommonPath(path, smt.root.Path).BitLen() != smt.root.Path.BitLen() { + return nil, ErrWrongShard } // Create a new hasher to ensure thread safety @@ -445,7 +548,7 @@ func (smt *SparseMerkleTree) GetPath(path *big.Int) *api.MerkleTreePath { return &api.MerkleTreePath{ Root: rootHash.ToHex(), Steps: steps, - } + }, nil } // generatePath recursively generates the Merkle tree path steps @@ -459,9 +562,13 @@ func (smt *SparseMerkleTree) generatePath(hasher *api.DataHasher, remainingPath // Create the corresponding leaf hash step currentLeaf, _ := currentNode.(*LeafBranch) path := currentLeaf.Path.String() - data := hex.EncodeToString(currentLeaf.Value) + var data *string + if currentLeaf.Value != nil { + tmp := hex.EncodeToString(currentLeaf.Value) + data = &tmp + } return []api.MerkleTreeStep{ - {Path: path, Data: &data}, + {Path: path, Data: data}, } } @@ -470,74 +577,70 @@ func (smt *SparseMerkleTree) generatePath(hasher *api.DataHasher, remainingPath panic("Unknown target branch type") } + var path *big.Int + if currentBranch.isRoot && currentBranch.Path.BitLen() > 1 { + // This is root of a child tree in sharded setup + // The path to add is the last bit of the shard ID + pos := currentBranch.Path.BitLen() - 2 + path = big.NewInt(int64(0b10 | currentBranch.Path.Bit(pos))) + } else { + // In all other cases we just add the actual path + path = currentBranch.Path + } + + var leftHash, rightHash *string + if currentBranch.Left != nil { + hash := currentBranch.Left.calculateHash(hasher) + if hash != nil { + tmp := hex.EncodeToString(hash.RawHash) + leftHash = &tmp + } + } + if currentBranch.Right != nil { + hash := currentBranch.Right.calculateHash(hasher) + if hash != nil { + tmp := hex.EncodeToString(hash.RawHash) + rightHash = &tmp + } + } + commonPath := calculateCommonPath(remainingPath, currentBranch.Path) - if commonPath.length < uint(currentBranch.Path.BitLen()-1) { + if currentBranch != smt.root && commonPath.BitLen() < currentBranch.Path.BitLen() { // Remaining path diverges or ends here - // Root node is a special case, because of its empty path // Create the corresponding 2-step proof - // No nil children in non-root nodes - leftHash := hex.EncodeToString(currentBranch.Left.calculateHash(hasher).RawHash) - rightHash := hex.EncodeToString(currentBranch.Right.calculateHash(hasher).RawHash) - // This looks weird, but see the effect in api.MerkleTreePath.Verify() return []api.MerkleTreeStep{ - {Path: "0", Data: &rightHash}, - {Path: currentBranch.Path.String(), Data: &leftHash}, + {Path: "0", Data: leftHash}, + {Path: path.String(), Data: rightHash}, } } // Trim remaining path for descending into subtree - remainingPath = new(big.Int).Rsh(remainingPath, commonPath.length) + remainingPath = new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)) - var target, sibling branch + var step api.MerkleTreeStep + var steps []api.MerkleTreeStep if remainingPath.Bit(0) == 0 { - // Target in the left child - target = currentBranch.Left - sibling = currentBranch.Right - } else { - // Target in the right child - target = currentBranch.Right - sibling = currentBranch.Left - } - - if target == nil { - // Target branch empty - // This can happen only at the root node - // Create the 2-step exclusion proof - // There may be nil children here - var leftHash, rightHash *string - if currentBranch.Left != nil { - tmp := hex.EncodeToString(currentBranch.Left.calculateHash(hasher).RawHash) - leftHash = &tmp - } - if currentBranch.Right != nil { - tmp := hex.EncodeToString(currentBranch.Right.calculateHash(hasher).RawHash) - rightHash = &tmp + // Target in the left child, right child is sibling + step = api.MerkleTreeStep{Path: path.String(), Data: rightHash} + if leftHash == nil { + steps = []api.MerkleTreeStep{{Path: "0", Data: nil}} + } else { + steps = smt.generatePath(hasher, remainingPath, currentBranch.Left) } - // This looks weird, but see the effect in api.MerkleTreePath.Verify() - return []api.MerkleTreeStep{ - {Path: "0", Data: rightHash}, - {Path: "1", Data: leftHash}, + } else { + step = api.MerkleTreeStep{Path: path.String(), Data: leftHash} + // Target in the right child, left child is sibling + if rightHash == nil { + steps = []api.MerkleTreeStep{{Path: "1", Data: nil}} + } else { + steps = smt.generatePath(hasher, remainingPath, currentBranch.Right) } } - - steps := smt.generatePath(hasher, remainingPath, target) - - // Add the step for the current branch - step := api.MerkleTreeStep{ - Path: currentBranch.Path.String(), - } - if sibling != nil { - tmp := hex.EncodeToString(sibling.calculateHash(hasher).RawHash) - step.Data = &tmp - } return append(steps, step) } // calculateCommonPath computes the longest common prefix of path1 and path2 -func calculateCommonPath(path1, path2 *big.Int) struct { - length uint - path *big.Int -} { +func calculateCommonPath(path1, path2 *big.Int) *big.Int { if path1.Sign() != 1 || path2.Sign() != 1 { panic("Non-positive path value") } @@ -553,10 +656,7 @@ func calculateCommonPath(path1, path2 *big.Int) struct { res.And(res, path1) // res &= path res.Or(res, mask) // res |= mask - return struct { - length uint - path *big.Int - }{uint(pos), res} + return res } // Leaf represents a leaf to be inserted (for batch operations) @@ -572,3 +672,25 @@ func NewLeaf(path *big.Int, value []byte) *Leaf { Value: append([]byte(nil), value...), } } + +// JoinPaths joins the hash proofs from a child and parent in sharded setting +func JoinPaths(child, parent *api.MerkleTreePath) (*api.MerkleTreePath, error) { + if len(child.Root) < 4 { + return nil, fmt.Errorf("invalid child root hash format") + } + if len(parent.Steps) == 0 { + return nil, fmt.Errorf("empty parent hash steps") + } + if parent.Steps[0].Data == nil || *parent.Steps[0].Data != child.Root[4:] { + return nil, fmt.Errorf("can't join paths: child root hash does not match parent input hash") + + } + steps := make([]api.MerkleTreeStep, len(child.Steps)+len(parent.Steps)-1) + copy(steps, child.Steps) + copy(steps[len(child.Steps):], parent.Steps[1:]) + + return &api.MerkleTreePath{ + Root: parent.Root, + Steps: steps, + }, nil +} diff --git a/internal/smt/smt_debug_test.go b/internal/smt/smt_debug_test.go index a96e8c5..fd70b30 100644 --- a/internal/smt/smt_debug_test.go +++ b/internal/smt/smt_debug_test.go @@ -35,7 +35,8 @@ func TestAddLeaves_DebugInvalidPath(t *testing.T) { require.NoError(t, err, "Expected error due to invalid path") // now validate the path of request - merkleTreePath := tree.GetPath(path) + merkleTreePath, err := tree.GetPath(path) + require.NoError(t, err) require.NotNil(t, merkleTreePath, "Expected non-nil Merkle tree path for valid request ID") res, err := merkleTreePath.Verify(path) @@ -84,7 +85,8 @@ func TestAddLeaves_DebugInvalidPath(t *testing.T) { require.NoError(t, err, "Failed to create request ID") path, err := req.GetPath() require.NoError(t, err) - merkleTreePath := _smt.GetPath(path) + merkleTreePath, err := _smt.GetPath(path) + require.NoError(t, err) require.NotNil(t, merkleTreePath, "Expected non-nil Merkle tree path for valid request ID") res, err := merkleTreePath.Verify(path) diff --git a/internal/smt/smt_test.go b/internal/smt/smt_test.go index 9272e0a..63f22b4 100644 --- a/internal/smt/smt_test.go +++ b/internal/smt/smt_test.go @@ -15,12 +15,14 @@ import ( // TestSMTGetRoot test basic SMT root hash computation func TestSMTGetRoot(t *testing.T) { + // "Singleton" example from the spec t.Run("EmptyTree", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) expected := "00001e54402898172f2948615fb17627733abbd120a85381c624ad060d28321be672" require.Equal(t, expected, smt.GetRootHashHex()) }) + // "Left Child Only" example from the spec t.Run("LeftLeaf", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) smt.AddLeaf(big.NewInt(0b100), []byte{0x61}) @@ -29,6 +31,7 @@ func TestSMTGetRoot(t *testing.T) { require.Equal(t, expected, smt.GetRootHashHex()) }) + // "Right Child Only" example from the spec t.Run("RightLeaf", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) smt.AddLeaf(big.NewInt(0b111), []byte{0x62}) @@ -37,6 +40,7 @@ func TestSMTGetRoot(t *testing.T) { require.Equal(t, expected, smt.GetRootHashHex()) }) + // "Two Leaves" example from the spec t.Run("TwoLeaves", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) smt.AddLeaf(big.NewInt(0b100), []byte{0x61}) @@ -46,6 +50,7 @@ func TestSMTGetRoot(t *testing.T) { require.Equal(t, expected, smt.GetRootHashHex()) }) + // "Four Leaves" example from the spec t.Run("FourLeaves", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 3) smt.AddLeaf(big.NewInt(0b1000), []byte{0x61}) @@ -58,6 +63,72 @@ func TestSMTGetRoot(t *testing.T) { }) } +func TestChildSMTGetRoot(t *testing.T) { + // Left child of the "Two Leaves, Sharded" example from the spec + t.Run("LeftOfTwoLeaves", func(t *testing.T) { + smt := NewChildSparseMerkleTree(api.SHA256, 2, 0b10) + smt.AddLeaf(big.NewInt(0b100), []byte{0x61}) + + expected := "0000256aedd9f31e69a4b0803616beab77234bae5dff519a10e519a0753be49f0534" + require.Equal(t, expected, smt.GetRootHashHex()) + }) + + // Right child of the "Two Leaves, Sharded" example from the spec + t.Run("RightOfTwoLeaves", func(t *testing.T) { + smt := NewChildSparseMerkleTree(api.SHA256, 2, 0b11) + smt.AddLeaf(big.NewInt(0b111), []byte{0x62}) + + expected := "0000e777763b4ce391c2f8acdf480dd64758bc8063a3aa5f62670a499a61d3bc7b9a" + require.Equal(t, expected, smt.GetRootHashHex()) + }) + + // Left child of the "Four Leaves, Sharded" example from the spec + t.Run("LeftOfFourLeaves", func(t *testing.T) { + smt := NewChildSparseMerkleTree(api.SHA256, 4, 0b100) + smt.AddLeaf(big.NewInt(0b10000), []byte{0x61}) + smt.AddLeaf(big.NewInt(0b11100), []byte{0x62}) + + expected := "0000a602dc13e4932c8d58196cdd34b44c44ff457323e7dcec9e5ea05d789bd28936" + require.Equal(t, expected, smt.GetRootHashHex()) + }) + + // Right child of the "Four Leaves, Sharded" example from the spec + t.Run("RightOfFourLeaves", func(t *testing.T) { + smt := NewChildSparseMerkleTree(api.SHA256, 4, 0b111) + smt.AddLeaf(big.NewInt(0b10011), []byte{0x63}) + smt.AddLeaf(big.NewInt(0b11111), []byte{0x64}) + + expected := "0000d1d4fd1c4b4e332427d726c39a2cea17ed4c59bff0458232ccb36199bb8849af" + require.Equal(t, expected, smt.GetRootHashHex()) + }) +} + +func TestParentSMTGetRoot(t *testing.T) { + // Parent of the "Two Leaves, Sharded" example from the spec + t.Run("TwoLeaves", func(t *testing.T) { + left, _ := hex.DecodeString("256aedd9f31e69a4b0803616beab77234bae5dff519a10e519a0753be49f0534") + right, _ := hex.DecodeString("e777763b4ce391c2f8acdf480dd64758bc8063a3aa5f62670a499a61d3bc7b9a") + smt := NewParentSparseMerkleTree(api.SHA256, 1) + smt.AddLeaf(big.NewInt(0b10), left) + smt.AddLeaf(big.NewInt(0b11), right) + + expected := "0000413b961d0069adfea0b4e122cf6dbf98e0a01ef7fd573d68c084ddfa03e4f9d6" + require.Equal(t, expected, smt.GetRootHashHex()) + }) + + // Parent of the "Four Leaves, Sharded" example from the spec + t.Run("FourLeaves", func(t *testing.T) { + left, _ := hex.DecodeString("a602dc13e4932c8d58196cdd34b44c44ff457323e7dcec9e5ea05d789bd28936") + right, _ := hex.DecodeString("d1d4fd1c4b4e332427d726c39a2cea17ed4c59bff0458232ccb36199bb8849af") + smt := NewParentSparseMerkleTree(api.SHA256, 2) + smt.AddLeaf(big.NewInt(0b100), left) + smt.AddLeaf(big.NewInt(0b111), right) + + expected := "0000ee27435446dd026d9f6baca2033ebffe2d29d8948eb81bf9250f7512323c6cbc" + require.Equal(t, expected, smt.GetRootHashHex()) + }) +} + // TestSMTBatchOperations tests batch functionality func TestSMTBatchOperations(t *testing.T) { t.Run("SimpleRetrievalTest", func(t *testing.T) { @@ -184,18 +255,16 @@ func TestSMTCommonPath(t *testing.T) { testCases := []struct { path1 *big.Int path2 *big.Int - expLen uint expPath *big.Int }{ - {big.NewInt(0b11), big.NewInt(0b111101111), 1, big.NewInt(0b11)}, - {big.NewInt(0b111101111), big.NewInt(0b11), 1, big.NewInt(0b11)}, - {big.NewInt(0b110010000), big.NewInt(0b100010000), 7, big.NewInt(0b10010000)}, + {big.NewInt(0b11), big.NewInt(0b111101111), big.NewInt(0b11)}, + {big.NewInt(0b111101111), big.NewInt(0b11), big.NewInt(0b11)}, + {big.NewInt(0b110010000), big.NewInt(0b100010000), big.NewInt(0b10010000)}, } for i, tc := range testCases { result := calculateCommonPath(tc.path1, tc.path2) - assert.Equal(t, tc.expLen, result.length, "Test %d: length mismatch", i) - assert.Equal(t, tc.expPath, result.path, "Test %d: path mismatch", i) + assert.Equal(t, tc.expPath, result, "Test %d: path mismatch", i) } } @@ -417,7 +486,8 @@ func TestSMTGetPath(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Test getting path for an existing leaf - merklePath := smt.GetPath(path) + merklePath, err := smt.GetPath(path) + require.NoError(t, err) require.NotNil(t, merklePath, "GetPath should return a path") require.Equal(t, smt.GetRootHashHex(), merklePath.Root, "Root hash should match expected value") require.NotNil(t, merklePath.Steps, "Steps should not be nil") @@ -443,7 +513,8 @@ func TestSMTGetPath(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Test getting path for an existing leaf - path := smt.GetPath(big.NewInt(0b10)) + path, err := smt.GetPath(big.NewInt(0b10)) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path") require.NotEmpty(t, path.Root, "Root hash should not be empty") require.NotNil(t, path.Steps, "Steps should not be nil") @@ -463,7 +534,8 @@ func TestSMTGetPath(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Test getting path for a non-existent leaf - path := smt.GetPath(big.NewInt(0b11)) + path, err := smt.GetPath(big.NewInt(0b11)) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path even for non-existent leaves") require.NotEmpty(t, path.Root, "Root hash should not be empty") require.NotNil(t, path.Steps, "Steps should not be nil") @@ -482,7 +554,8 @@ func TestSMTGetPath(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Test path structure - path := smt.GetPath(big.NewInt(0b10)) + path, err := smt.GetPath(big.NewInt(0b10)) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path") // Verify step structure @@ -498,7 +571,8 @@ func TestSMTGetPath(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 1) // Test getting path from empty tree - path := smt.GetPath(big.NewInt(0b10)) + path, err := smt.GetPath(big.NewInt(0b10)) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path even for empty tree") require.NotEmpty(t, path.Root, "Root hash should not be empty even for empty tree") require.NotNil(t, path.Steps, "Steps should not be nil") @@ -519,7 +593,8 @@ func TestSMTGetPathComprehensive(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Get path for the leaf - path := smt.GetPath(leafPath) + path, err := smt.GetPath(leafPath) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path") require.Equal(t, smt.GetRootHashHex(), path.Root, "Path root should match tree root") @@ -545,13 +620,15 @@ func TestSMTGetPathComprehensive(t *testing.T) { require.NoError(t, err, "AddLeaf 2 failed") // Get path for first leaf - merkPath1 := smt.GetPath(path1) + merkPath1, err := smt.GetPath(path1) + require.NoError(t, err) require.NotNil(t, merkPath1, "GetPath should return a path") require.Equal(t, smt.GetRootHashHex(), merkPath1.Root, "Path root should match tree root") require.NotEmpty(t, merkPath1.Steps, "Should have steps") // Get path for second leaf - merkPath2 := smt.GetPath(path2) + merkPath2, err := smt.GetPath(path2) + require.NoError(t, err) require.NotNil(t, merkPath2, "GetPath should return a path") require.Equal(t, smt.GetRootHashHex(), merkPath2.Root, "Path root should match tree root") require.NotEmpty(t, merkPath2.Steps, "Should have steps") @@ -585,8 +662,8 @@ func TestSMTGetPathComprehensive(t *testing.T) { // Try to get path for non-existent leaf nonExistentPath := big.NewInt(0b111) - merkPath := smt.GetPath(nonExistentPath) - + merkPath, err := smt.GetPath(nonExistentPath) + require.NoError(t, err) require.NotNil(t, merkPath, "GetPath should return a path even for non-existent paths") require.Equal(t, smt.GetRootHashHex(), merkPath.Root, "Path root should match tree root") require.NotEmpty(t, merkPath.Steps, "Should have steps even for non-existent path") @@ -624,7 +701,8 @@ func TestSMTGetPathComprehensive(t *testing.T) { rootHash := smt.GetRootHashHex() for i, path := range testPaths { - merkPath := smt.GetPath(path) + merkPath, err := smt.GetPath(path) + require.NoError(t, err) require.NotNil(t, merkPath, "GetPath should return a path for leaf %d", i) require.Equal(t, rootHash, merkPath.Root, "All paths should have same root") require.NotEmpty(t, merkPath.Steps, "Path should have steps for leaf %d", i) @@ -652,14 +730,15 @@ func TestSMTGetPathComprehensive(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) // Get path from empty tree - path := smt.GetPath(big.NewInt(0b101)) + path, err := smt.GetPath(big.NewInt(0b101)) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path even for empty tree") require.NotEmpty(t, path.Root, "Root should not be empty even for empty tree") require.NotNil(t, path.Steps, "Steps should not be nil") require.Len(t, path.Steps, 2, "Should have two steps") step0 := path.Steps[0] - require.Equal(t, "0", step0.Path, "Input step path") + require.Equal(t, "1", step0.Path, "Input step path") require.Nil(t, step0.Data, "Empty tree step should have no data") step1 := path.Steps[1] @@ -690,7 +769,8 @@ func TestSMTGetPathComprehensive(t *testing.T) { // Get paths and validate structure for _, leaf := range testLeaves { - merkPath := smt.GetPath(leaf.path) + merkPath, err := smt.GetPath(leaf.path) + require.NoError(t, err) require.NotNil(t, merkPath, "GetPath should return a path") // Validate path structure for verification compatibility @@ -727,10 +807,10 @@ func TestSMTGetPathComprehensive(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Get paths multiple times and verify consistency - merkPath1a := smt.GetPath(path1) - merkPath1b := smt.GetPath(path1) - merkPath2a := smt.GetPath(path2) - merkPath2b := smt.GetPath(path2) + merkPath1a, _ := smt.GetPath(path1) + merkPath1b, _ := smt.GetPath(path1) + merkPath2a, _ := smt.GetPath(path2) + merkPath2b, _ := smt.GetPath(path2) // Same path should return identical results require.Equal(t, merkPath1a.Root, merkPath1b.Root, "Same path should have same root") @@ -777,7 +857,8 @@ func TestSMTGetPathComprehensive(t *testing.T) { require.NoError(t, err, "AddLeaf failed for %s", tc.name) // Get path - merkPath := smt.GetPath(path) + merkPath, err := smt.GetPath(path) + require.NoError(t, err) require.NotNil(t, merkPath, "GetPath should return a path for %s", tc.name) // Verify the path representation in steps @@ -1191,32 +1272,119 @@ func TestSMTOrderDependencyBatch(t *testing.T) { assert.Equal(t, hash1, hash2, "SMT additions should be order-independent") } -// TestSMTAddingNodeUnderLeaf - Test that the SMT does not allow adding child nodes under existing leaves -// TODO: this is now a test that the tree rejects insertions with wrong key length -func TestSMTAddingNodeUnderLeaf(t *testing.T) { - smt1 := NewSparseMerkleTree(api.SHA256, 1) - require.NoError(t, smt1.AddLeaf(big.NewInt(2), []byte("leaf_1"))) - require.Error(t, smt1.AddLeaf(big.NewInt(4), []byte("child_under_leaf_1")), "SMT should not allow adding child nodes under leaves") +// TestSMTKeyLength - Test that the SMT does not allow adding leaves with wrong key lengths +func TestSMTKeyLength(t *testing.T) { + smt := NewSparseMerkleTree(api.SHA256, 3) - smt2 := NewSparseMerkleTree(api.SHA256, 1) - leaves2 := []*Leaf{ - {Path: big.NewInt(2), Value: []byte("leaf_1")}, - {Path: big.NewInt(4), Value: []byte("child_under_leaf_1")}, + require.Error(t, smt.AddLeaf(big.NewInt(0b100), []byte("leaf_1")), "SMT should not allow adding leaves with too short keys") + + leaves1 := []*Leaf{ + {Path: big.NewInt(0b1000), Value: []byte("leaf_1")}, // OK + {Path: big.NewInt(0b111), Value: []byte("leaf_2")}, // too short key } - require.Error(t, smt2.AddLeaves(leaves2), "SMT should not allow adding child nodes under leaves, even in a batch") -} + require.Error(t, smt.AddLeaves(leaves1), "SMT should not allow adding leaves with too short keys, even in a batch") -// TestSMTAddingLeafAboveNode - Test that the SMT does not allow adding leaves above existing nodes -// TODO: this is now a test that the tree rejects insertions with wrong key length -func TestSMTAddingLeafAboveNode(t *testing.T) { - smt1 := NewSparseMerkleTree(api.SHA256, 2) - require.NoError(t, smt1.AddLeaf(big.NewInt(4), []byte("leaf_1"))) - require.Error(t, smt1.AddLeaf(big.NewInt(2), []byte("node_above_leaf_1")), "SMT should not allow adding leaves above existing nodes") + require.Error(t, smt.AddLeaf(big.NewInt(0b10000), []byte("leaf_1")), "SMT should not allow adding leaves with too long keys") - smt2 := NewSparseMerkleTree(api.SHA256, 2) leaves2 := []*Leaf{ - {Path: big.NewInt(4), Value: []byte("leaf_1")}, - {Path: big.NewInt(2), Value: []byte("node_above_leaf_1")}, + {Path: big.NewInt(0b1000), Value: []byte("leaf_1")}, // OK + {Path: big.NewInt(0b11111), Value: []byte("leaf_2")}, // too long key } - require.Error(t, smt2.AddLeaves(leaves2), "SMT should not allow adding leaves above existing nodes, even in a batch") + require.Error(t, smt.AddLeaves(leaves2), "SMT should not allow adding leaves with too long keys, even in a batch") +} + +func TestJoinPaths(t *testing.T) { + // "Two Leaves, Sharded" example from the spec + t.Run("TwoLeaves", func(t *testing.T) { + left := NewChildSparseMerkleTree(api.SHA256, 2, 0b10) + left.AddLeaf(big.NewInt(0b100), []byte{0x61}) + + right := NewChildSparseMerkleTree(api.SHA256, 2, 0b11) + right.AddLeaf(big.NewInt(0b111), []byte{0x62}) + + parent := NewParentSparseMerkleTree(api.SHA256, 1) + parent.AddLeaf(big.NewInt(0b10), left.GetRootHash()[2:]) + parent.AddLeaf(big.NewInt(0b11), right.GetRootHash()[2:]) + + leftChild, _ := left.GetPath(big.NewInt(0b100)) + leftParent, _ := parent.GetPath(big.NewInt(0b10)) + leftPath, err := JoinPaths(leftChild, leftParent) + assert.Nil(t, err) + assert.NotNil(t, leftPath) + leftRes, err := leftPath.Verify(big.NewInt(0b100)) + assert.Nil(t, err) + assert.NotNil(t, leftRes) + assert.True(t, leftRes.PathValid) + assert.True(t, leftRes.PathIncluded) + + rightChild, _ := right.GetPath(big.NewInt(0b111)) + rightParent, _ := parent.GetPath(big.NewInt(0b11)) + rightPath, err := JoinPaths(rightChild, rightParent) + assert.Nil(t, err) + assert.NotNil(t, rightPath) + rightRes, err := rightPath.Verify(big.NewInt(0b111)) + assert.Nil(t, err) + assert.NotNil(t, rightRes) + assert.True(t, rightRes.PathValid) + assert.True(t, rightRes.PathIncluded) + }) + + // "Four Leaves, Sharded" example from the spec + t.Run("FourLeaves", func(t *testing.T) { + left := NewChildSparseMerkleTree(api.SHA256, 4, 0b100) + left.AddLeaf(big.NewInt(0b10000), []byte{0x61}) + left.AddLeaf(big.NewInt(0b11100), []byte{0x62}) + + right := NewChildSparseMerkleTree(api.SHA256, 4, 0b111) + right.AddLeaf(big.NewInt(0b10011), []byte{0x63}) + right.AddLeaf(big.NewInt(0b11111), []byte{0x64}) + + parent := NewParentSparseMerkleTree(api.SHA256, 2) + parent.AddLeaf(big.NewInt(0b100), left.GetRootHash()[2:]) + parent.AddLeaf(big.NewInt(0b111), right.GetRootHash()[2:]) + + child1, _ := left.GetPath(big.NewInt(0b10000)) + parent1, _ := parent.GetPath(big.NewInt(0b100)) + path1, err := JoinPaths(child1, parent1) + assert.Nil(t, err) + assert.NotNil(t, path1) + res1, err := path1.Verify(big.NewInt(0b10000)) + assert.Nil(t, err) + assert.NotNil(t, res1) + assert.True(t, res1.PathValid) + assert.True(t, res1.PathIncluded) + + child2, _ := left.GetPath(big.NewInt(0b11100)) + parent2, _ := parent.GetPath(big.NewInt(0b100)) + path2, err := JoinPaths(child2, parent2) + assert.Nil(t, err) + assert.NotNil(t, path2) + res2, err := path2.Verify(big.NewInt(0b11100)) + assert.Nil(t, err) + assert.NotNil(t, res2) + assert.True(t, res2.PathValid) + assert.True(t, res2.PathIncluded) + + child3, _ := right.GetPath(big.NewInt(0b10011)) + parent3, _ := parent.GetPath(big.NewInt(0b111)) + path3, err := JoinPaths(child3, parent3) + assert.Nil(t, err) + assert.NotNil(t, path3) + res3, err := path3.Verify(big.NewInt(0b10011)) + assert.Nil(t, err) + assert.NotNil(t, res3) + assert.True(t, res3.PathValid) + assert.True(t, res3.PathIncluded) + + child4, _ := right.GetPath(big.NewInt(0b11111)) + parent4, _ := parent.GetPath(big.NewInt(0b111)) + path4, err := JoinPaths(child4, parent4) + assert.Nil(t, err) + assert.NotNil(t, path4) + res4, err := path4.Verify(big.NewInt(0b11111)) + assert.Nil(t, err) + assert.NotNil(t, res4) + assert.True(t, res4.PathValid) + assert.True(t, res4.PathIncluded) + }) } diff --git a/internal/smt/thread_safe_smt.go b/internal/smt/thread_safe_smt.go index c3478aa..f36b2ec 100644 --- a/internal/smt/thread_safe_smt.go +++ b/internal/smt/thread_safe_smt.go @@ -65,7 +65,7 @@ func (ts *ThreadSafeSMT) GetLeaf(path *big.Int) (*LeafBranch, error) { // GetPath generates a Merkle tree path for the given path // This is a read operation and allows concurrent access -func (ts *ThreadSafeSMT) GetPath(path *big.Int) *api.MerkleTreePath { +func (ts *ThreadSafeSMT) GetPath(path *big.Int) (*api.MerkleTreePath, error) { ts.rwMux.RLock() defer ts.rwMux.RUnlock() return ts.smt.GetPath(path) diff --git a/internal/smt/thread_safe_smt_snapshot.go b/internal/smt/thread_safe_smt_snapshot.go index ee3c89b..2e6e145 100644 --- a/internal/smt/thread_safe_smt_snapshot.go +++ b/internal/smt/thread_safe_smt_snapshot.go @@ -65,7 +65,7 @@ func (tss *ThreadSafeSmtSnapshot) GetRootHash() string { return tss.snapshot.GetRootHashHex() } -func (tss *ThreadSafeSmtSnapshot) GetPath(path *big.Int) *api.MerkleTreePath { +func (tss *ThreadSafeSmtSnapshot) GetPath(path *big.Int) (*api.MerkleTreePath, error) { tss.rwMux.RLock() defer tss.rwMux.RUnlock() diff --git a/internal/smt/thread_safe_smt_snapshot_test.go b/internal/smt/thread_safe_smt_snapshot_test.go index 547af66..59f8d1f 100644 --- a/internal/smt/thread_safe_smt_snapshot_test.go +++ b/internal/smt/thread_safe_smt_snapshot_test.go @@ -167,7 +167,8 @@ func TestThreadSafeSMTSnapshot(t *testing.T) { assert.Equal(t, value, leaf.Value, "Retrieved leaf value should match") // Test path generation on original SMT - merkleTreePath := threadSafeSMT.GetPath(path) + merkleTreePath, err := threadSafeSMT.GetPath(path) + require.NoError(t, err) assert.NotNil(t, merkleTreePath, "Should be able to get Merkle tree path from original SMT") assert.NotEmpty(t, merkleTreePath.Root, "Root should not be empty") assert.NotEmpty(t, merkleTreePath.Steps, "Steps should not be empty") diff --git a/pkg/api/merkle_tree_path_verify_test.go b/pkg/api/merkle_tree_path_verify_test.go index d1a67bf..0315339 100644 --- a/pkg/api/merkle_tree_path_verify_test.go +++ b/pkg/api/merkle_tree_path_verify_test.go @@ -28,7 +28,8 @@ func TestMerkleTreePathVerify(t *testing.T) { err := tree.AddLeaves([]*smt.Leaf{leaf}) require.NoError(t, err) - path := tree.GetPath(big.NewInt(42)) + path, err := tree.GetPath(big.NewInt(42)) + require.NoError(t, err) require.NotNil(t, path) result, err := path.Verify(big.NewInt(42)) @@ -49,7 +50,8 @@ func TestMerkleTreePathVerify(t *testing.T) { // Verify both paths for _, leafPath := range []int64{10, 12} { - path := tree.GetPath(big.NewInt(leafPath)) + path, err := tree.GetPath(big.NewInt(leafPath)) + require.NoError(t, err) require.NotNil(t, path) result, err := path.Verify(big.NewInt(leafPath)) @@ -75,7 +77,8 @@ func TestMerkleTreePathVerify(t *testing.T) { // Verify each path for _, p := range paths { - path := tree.GetPath(big.NewInt(0x1000000000000 + p)) + path, err := tree.GetPath(big.NewInt(0x1000000000000 + p)) + require.NoError(t, err) require.NotNil(t, path) result, err := path.Verify(big.NewInt(0x1000000000000 + p)) @@ -101,7 +104,8 @@ func TestMerkleTreePathVerify(t *testing.T) { require.NoError(t, err) // Verify transfer path - path := tree.GetPath(transferPath) + path, err := tree.GetPath(transferPath) + require.NoError(t, err) require.NotNil(t, path) result, err := path.Verify(transferPath) @@ -110,7 +114,8 @@ func TestMerkleTreePathVerify(t *testing.T) { require.True(t, result.PathValid, "Transfer path should be valid") // Verify mint path - pathMint := tree.GetPath(mintPath) + pathMint, err := tree.GetPath(mintPath) + require.NoError(t, err) require.NotNil(t, pathMint) resultMint, err := pathMint.Verify(mintPath) @@ -132,7 +137,8 @@ func TestMerkleTreePathVerify(t *testing.T) { // a valid path showing where that leaf would be inserted. Since leaf 1000 goes // left (bit 0 = 0) and 999 would go right (bit 0 = 1), we get a path to the // empty right branch with the left subtree as sibling. - path := tree.GetPath(big.NewInt(999)) + path, err := tree.GetPath(big.NewInt(999)) + require.NoError(t, err) require.NotNil(t, path) // When we verify this path with requestId 999: @@ -155,7 +161,8 @@ func TestMerkleTreePathVerify(t *testing.T) { require.NoError(t, err) // Get path for 5 - path5 := tree.GetPath(big.NewInt(0x1000 + 5)) + path5, err := tree.GetPath(big.NewInt(0x1000 + 5)) + require.NoError(t, err) // Try to verify with wrong requestId result, err := path5.Verify(big.NewInt(0x1000 + 15)) @@ -192,7 +199,8 @@ func TestMerkleTreePathVerify(t *testing.T) { require.NoError(t, err) for _, p := range tc.paths { - path := tree.GetPath(big.NewInt(p)) + path, err := tree.GetPath(big.NewInt(p)) + require.NoError(t, err) result, err := path.Verify(big.NewInt(p)) require.NoError(t, err) require.True(t, result.PathIncluded && result.PathValid, @@ -229,13 +237,15 @@ func TestMerkleTreePathVerify(t *testing.T) { require.NoError(t, err) // Verify paths - treePath1 := tree.GetPath(path1) + treePath1, err := tree.GetPath(path1) + require.NoError(t, err) result1, err := treePath1.Verify(path1) require.NoError(t, err) require.True(t, result1.PathIncluded && result1.PathValid, "RequestID1 path should be valid") - treePath2 := tree.GetPath(path2) + treePath2, err := tree.GetPath(path2) + require.NoError(t, err) result2, err := treePath2.Verify(path2) require.NoError(t, err) require.True(t, result2.PathIncluded && result2.PathValid, @@ -253,7 +263,8 @@ func TestMerkleTreePathVerify(t *testing.T) { // Verify all previously added leaves still work for j := int64(1); j <= i; j++ { - path := tree.GetPath(big.NewInt(0x100000 + j*100)) + path, err := tree.GetPath(big.NewInt(0x100000 + j*100)) + require.NoError(t, err) result, err := path.Verify(big.NewInt(0x100000 + j*100)) require.NoError(t, err) require.True(t, result.PathIncluded && result.PathValid, @@ -317,7 +328,8 @@ func TestMerkleTreePathVerifyDuplicates(t *testing.T) { require.Error(t, err) // Verify the original value is still there - path := tree.GetPath(big.NewInt(100)) + path, err := tree.GetPath(big.NewInt(100)) + require.NoError(t, err) result, err := path.Verify(big.NewInt(100)) require.NoError(t, err) require.True(t, result.PathIncluded && result.PathValid, @@ -343,7 +355,8 @@ func TestMerkleTreePathVerifyAlternateAlgorithm(t *testing.T) { require.Equal(t, root[:4], fmt.Sprintf("%04x", algo)) for _, leaf := range leaves { - path := tree.GetPath(leaf.Path) + path, err := tree.GetPath(leaf.Path) + require.NoError(t, err) require.Equal(t, root, path.Root) res, err := path.Verify(leaf.Path) require.NoError(t, err) diff --git a/pkg/api/smt.go b/pkg/api/smt.go index cb7ace0..5f9d446 100644 --- a/pkg/api/smt.go +++ b/pkg/api/smt.go @@ -42,7 +42,7 @@ func (m *MerkleTreePath) Verify(requestId *big.Int) (*PathVerificationResult, er } // The "running totals" as we go through the hashing steps - currentPath := big.NewInt(1) + var currentPath *big.Int var currentData *[]byte for i, step := range m.Steps { @@ -61,7 +61,7 @@ func (m *MerkleTreePath) Verify(requestId *big.Int) (*PathVerificationResult, er } if i == 0 { - if stepPath.Sign() > 0 { + if stepPath.BitLen() >= 2 { // First step, normal case: data is the value in the leaf, apply the leaf hashing rule hasher.Reset().AddData(CborArray(2)) hasher.AddCborBytes(BigintEncode(stepPath)) @@ -73,8 +73,10 @@ func (m *MerkleTreePath) Verify(requestId *big.Int) (*PathVerificationResult, er currentData = &hasher.GetHash().RawHash } else { // First step, special case: data is the "our branch" hash value for the next step + // Note that in this case stepPath is a "naked" direction bit currentData = stepData } + currentPath = stepPath } else { // All subsequent steps: apply the non-leaf hashing rule var left, right *[]byte @@ -101,19 +103,21 @@ func (m *MerkleTreePath) Verify(requestId *big.Int) (*PathVerificationResult, er hasher.AddCborBytes(*right) } currentData = &hasher.GetHash().RawHash - } - if stepPath.Sign() > 0 { + // Initialization for when currentPath is a "naked" direction bit + if currentPath.BitLen() < 2 { + currentPath = big.NewInt(1) + } // Append step path bits to current path pathLen := stepPath.BitLen() - 1 - stepPath.SetBit(stepPath, pathLen, 0) + mask := new(big.Int).SetBit(stepPath, pathLen, 0) currentPath.Lsh(currentPath, uint(pathLen)) - currentPath.Or(currentPath, stepPath) + currentPath.Or(currentPath, mask) } } pathValid := currentData != nil && m.Root == NewDataHash(hasher.algorithm, *currentData).ToHex() - pathIncluded := requestId.Cmp(currentPath) == 0 + pathIncluded := currentPath != nil && requestId.Cmp(currentPath) == 0 return &PathVerificationResult{ PathValid: pathValid, diff --git a/pkg/api/types.go b/pkg/api/types.go index a1e20c4..aa7c1c1 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -15,6 +15,7 @@ import ( // Basic types for API type StateHash = ImprintHexString type TransactionHash = ImprintHexString +type ShardID = int64 // Authenticator represents the authentication data for a commitment type Authenticator struct {