diff --git a/tree.go b/tree.go index b6bb8d5c..04d256f8 100644 --- a/tree.go +++ b/tree.go @@ -835,6 +835,35 @@ func groupKeys(keys keylist, depth byte) []keylist { return groups } +func (n *InternalNode) GetAndLoadForProof(key []byte, resolver NodeResolverFn) ([]byte, error) { + // Each internal node that is part of the proof needs to load all its + // children since it's needed for proof openings. + childrenKey := make([]byte, n.depth+1) + copy(childrenKey, key[:n.depth]) + for i := range n.children { + if _, ok := n.children[i].(HashedNode); ok { + childrenKey[n.depth] = byte(i) + serialized, err := resolver(childrenKey) + if err != nil { + return nil, fmt.Errorf("resolving node: %s", err) + } + c, err := ParseNode(serialized, n.depth+1) + if err != nil { + return nil, fmt.Errorf("parsing resolved node: %s", err) + } + n.children[i] = c + } + } + switch child := n.children[key[n.depth+1]].(type) { + case *InternalNode: // If next node is an internal node, recurse. + return child.GetAndLoadForProof(childrenKey, resolver) + case *LeafNode: // If next node is a leaf node, return the value. + return child.Get(key, nil) + default: + panic("invalid node type") + } +} + func (n *InternalNode) GetProofItems(keys keylist, resolver NodeResolverFn) (*ProofElements, []byte, [][]byte, error) { var ( groups = groupKeys(keys, n.depth)