From 231ca20b659d12b9291c0f5e898e2ca51d6762f2 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut <134911583+absolutelightning@users.noreply.github.com> Date: Tue, 21 May 2024 13:52:47 +0530 Subject: [PATCH] Fix Test and Add tests for Track Mutate + Add LRU Cache for writeNode (#6) * reverse iterator init * some fixes * fix reverse iterator seaklowerbound * fix tests * fix track channels * longest prefix on txn * some fixes * revisit * todo revisit * major code refactor * some minor fixes * add lru * fix bugs * added walk func * add more tests * some prog * some progress track mutate * some progress * fix tests --- go.mod | 1 + go.sum | 2 + helpers.go | 64 +-- iterator.go | 17 +- iterator_test.go | 6 +- node_16.go | 1 + node_256.go | 1 + node_4.go | 1 + node_48.go | 1 + node_leaf.go | 1 + path_iter_test.go | 2 +- reverse_iterator_test.go | 16 +- tree.go | 260 ++++-------- tree_test.go | 845 ++++++++++++++++++++++++++++++++++----- txn.go | 454 +++++++++++++++++---- 15 files changed, 1254 insertions(+), 418 deletions(-) diff --git a/go.mod b/go.mod index 217756d..8ed1507 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21 require ( github.com/hashicorp/go-uuid v1.0.3 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 ) diff --git a/go.sum b/go.sum index c5f76ed..153bb13 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/helpers.go b/helpers.go index c24762e..5468adf 100644 --- a/helpers.go +++ b/helpers.go @@ -28,44 +28,6 @@ func leafMatches(nodeKey []byte, key []byte) int { return bytes.Compare(nodeKey, key) } -func (t *RadixTree[T]) makeLeaf(key []byte, value T) Node[T] { - // Allocate memory for the leaf node - l := t.allocNode(leafType) - - if l == nil { - return nil - } - - // Set the value and key length - l.setValue(value) - l.setKeyLen(uint32(len(key))) - l.setKey(key) - return l -} - -func (t *RadixTree[T]) allocNode(ntype nodeType) Node[T] { - var n Node[T] - switch ntype { - case leafType: - n = &NodeLeaf[T]{} - case node4: - n = &Node4[T]{} - case node16: - n = &Node16[T]{} - case node48: - n = &Node48[T]{} - case node256: - n = &Node256[T]{} - default: - panic("Unknown node type") - } - n.setMutateCh(make(chan struct{})) - n.setPartial(make([]byte, maxPrefixLen)) - n.setPartialLen(maxPrefixLen) - t.trachChns[n.getMutateCh()] = struct{}{} - return n -} - // longestCommonPrefix finds the length of the longest common prefix between two leaf nodes. func longestCommonPrefix[T any](l1, l2 Node[T], depth int) int { maxCmp := len(l2.getKey()) - depth @@ -82,7 +44,7 @@ func longestCommonPrefix[T any](l1, l2 Node[T], depth int) int { } // addChild adds a child node to the parent node. -func (t *RadixTree[T]) addChild(n Node[T], c byte, child Node[T]) Node[T] { +func (t *Txn[T]) addChild(n Node[T], c byte, child Node[T]) Node[T] { switch n.getArtNodeType() { case node4: return t.addChild4(n, c, child) @@ -98,7 +60,7 @@ func (t *RadixTree[T]) addChild(n Node[T], c byte, child Node[T]) Node[T] { } // addChild4 adds a child node to a node4. -func (t *RadixTree[T]) addChild4(n Node[T], c byte, child Node[T]) Node[T] { +func (t *Txn[T]) addChild4(n Node[T], c byte, child Node[T]) Node[T] { if n.getNumChildren() < 4 { idx := sort.Search(int(n.getNumChildren()), func(i int) bool { return n.getKeyAtIdx(i) > c @@ -124,7 +86,7 @@ func (t *RadixTree[T]) addChild4(n Node[T], c byte, child Node[T]) Node[T] { } // addChild16 adds a child node to a node16. -func (t *RadixTree[T]) addChild16(n Node[T], c byte, child Node[T]) Node[T] { +func (t *Txn[T]) addChild16(n Node[T], c byte, child Node[T]) Node[T] { if n.getNumChildren() < 16 { idx := sort.Search(int(n.getNumChildren()), func(i int) bool { return n.getKeyAtIdx(i) > c @@ -152,7 +114,7 @@ func (t *RadixTree[T]) addChild16(n Node[T], c byte, child Node[T]) Node[T] { } // addChild48 adds a child node to a node48. -func (t *RadixTree[T]) addChild48(n Node[T], c byte, child Node[T]) Node[T] { +func (t *Txn[T]) addChild48(n Node[T], c byte, child Node[T]) Node[T] { if n.getNumChildren() < 48 { pos := 0 for n.getChild(pos) != nil { @@ -175,14 +137,14 @@ func (t *RadixTree[T]) addChild48(n Node[T], c byte, child Node[T]) Node[T] { } // addChild256 adds a child node to a node256. -func (t *RadixTree[T]) addChild256(n Node[T], c byte, child Node[T]) Node[T] { +func (t *Txn[T]) addChild256(n Node[T], c byte, child Node[T]) Node[T] { n.setNumChildren(n.getNumChildren() + 1) n.setChild(int(c), child) return n } // copyHeader copies header information from src to dest node. -func (t *RadixTree[T]) copyHeader(dest, src Node[T]) { +func (t *Txn[T]) copyHeader(dest, src Node[T]) { dest.setNumChildren(src.getNumChildren()) dest.setPartialLen(src.getPartialLen()) length := min(maxPrefixLen, int(src.getPartialLen())) @@ -308,10 +270,6 @@ func isLeaf[T any](node Node[T]) bool { return node.isLeaf() } -// findChild finds the child node pointer based on the given character in the ART tree node. -func (t *RadixTree[T]) findChild(n Node[T], c byte) (Node[T], int) { - return findChild(n, c) -} func findChild[T any](n Node[T], c byte) (Node[T], int) { switch n.getArtNodeType() { case node4: @@ -360,7 +318,7 @@ func getKey(key []byte) []byte { return key[1 : len(key)-1] } -func (t *RadixTree[T]) removeChild(n Node[T], c byte) Node[T] { +func (t *Txn[T]) removeChild(n Node[T], c byte) Node[T] { switch n.getArtNodeType() { case node4: return t.removeChild4(n.(*Node4[T]), c) @@ -375,7 +333,7 @@ func (t *RadixTree[T]) removeChild(n Node[T], c byte) Node[T] { } } -func (t *RadixTree[T]) removeChild4(n *Node4[T], c byte) Node[T] { +func (t *Txn[T]) removeChild4(n *Node4[T], c byte) Node[T] { pos := sort.Search(int(n.numChildren), func(i int) bool { return n.keys[i] >= c }) @@ -412,7 +370,7 @@ func (t *RadixTree[T]) removeChild4(n *Node4[T], c byte) Node[T] { return n } -func (t *RadixTree[T]) removeChild16(n *Node16[T], c byte) Node[T] { +func (t *Txn[T]) removeChild16(n *Node16[T], c byte) Node[T] { pos := sort.Search(int(n.numChildren), func(i int) bool { return n.keys[i] >= c }) @@ -432,7 +390,7 @@ func (t *RadixTree[T]) removeChild16(n *Node16[T], c byte) Node[T] { return n } -func (t *RadixTree[T]) removeChild48(n *Node48[T], c uint8) Node[T] { +func (t *Txn[T]) removeChild48(n *Node48[T], c uint8) Node[T] { pos := n.keys[c] n.keys[c] = 0 n.children[pos-1] = nil @@ -455,7 +413,7 @@ func (t *RadixTree[T]) removeChild48(n *Node48[T], c uint8) Node[T] { return n } -func (t *RadixTree[T]) removeChild256(n *Node256[T], c uint8) Node[T] { +func (t *Txn[T]) removeChild256(n *Node256[T], c uint8) Node[T] { n.children[c] = nil n.numChildren-- diff --git a/iterator.go b/iterator.go index 0ba3be4..a017a01 100644 --- a/iterator.go +++ b/iterator.go @@ -19,6 +19,11 @@ type Iterator[T any] struct { lowerBound bool reverseLowerBound bool seenMismatch bool + iterPath []byte +} + +func (i *Iterator[T]) GetIterPath() []byte { + return i.iterPath } // Front returns the current node that has been iterated to. @@ -34,6 +39,7 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) { var zero T if len(i.stack) == 0 { + i.pos = nil return nil, zero, false } @@ -54,7 +60,6 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) { leafCh := currentNode.(*NodeLeaf[T]) if i.lowerBound { if bytes.Compare(getKey(leafCh.key), getKey(i.path)) >= 0 { - i.pos = leafCh return getKey(leafCh.key), leafCh.value, true } continue @@ -62,7 +67,6 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) { if len(i.Path()) >= 2 && !leafCh.matchPrefix([]byte(i.Path())) { continue } - i.pos = leafCh return getKey(leafCh.key), leafCh.value, true case node4: n4 := currentNode.(*Node4[T]) @@ -77,6 +81,7 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) { newStack[0] = child i.stack = newStack } + i.iterPath = append(i.iterPath, n4.getPartial()[:n4.getPartialLen()]...) case node16: n16 := currentNode.(*Node16[T]) for itr := 15; itr >= 0; itr-- { @@ -90,6 +95,7 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) { newStack[0] = child i.stack = newStack } + i.iterPath = append(i.iterPath, n16.getPartial()[:n16.getPartialLen()]...) case node48: n48 := currentNode.(*Node48[T]) for itr := 0; itr < 256; itr++ { @@ -107,6 +113,7 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) { newStack[0] = child i.stack = newStack } + i.iterPath = append(i.iterPath, n48.getPartial()[:n48.getPartialLen()]...) case node256: n256 := currentNode.(*Node256[T]) for itr := 255; itr >= 0; itr-- { @@ -120,6 +127,7 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) { newStack[0] = child i.stack = newStack } + i.iterPath = append(i.iterPath, n256.getPartial()[:n256.getPartialLen()]...) } } i.pos = nil @@ -127,7 +135,7 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) { } func (i *Iterator[T]) SeekPrefixWatch(prefixKey []byte) (watch <-chan struct{}) { - // Start from the node node + // Start from the node node := i.node watch = node.getMutateCh() @@ -142,6 +150,7 @@ func (i *Iterator[T]) SeekPrefixWatch(prefixKey []byte) (watch <-chan struct{}) // Check if the node matches the prefix i.stack = []Node[T]{node} i.node = node + watch = node.getMutateCh() if node.isLeaf() { return @@ -178,7 +187,7 @@ func (i *Iterator[T]) SeekPrefixWatch(prefixKey []byte) (watch <-chan struct{}) watch = node.getMutateCh() depth++ } - return watch + return } func (i *Iterator[T]) SeekPrefix(prefixKey []byte) { diff --git a/iterator_test.go b/iterator_test.go index 21a758f..b85feec 100644 --- a/iterator_test.go +++ b/iterator_test.go @@ -40,12 +40,12 @@ func TestIterateLowerBoundFuzz(t *testing.T) { // the same list as filtering all sorted keys that are lower. radixAddAndScan := func(newKey, searchKey readableString) []string { - r.Insert([]byte(newKey), "") + r, _, _ = r.Insert([]byte(newKey), "") t.Log("NewKey: ", newKey, "SearchKey: ", searchKey) // Now iterate the tree from searchKey to the end - it := r.root.Iterator() + it := r.Root().Iterator() var result []string it.SeekLowerBound([]byte(searchKey)) for { @@ -320,7 +320,7 @@ func TestIterateLowerBound(t *testing.T) { // Insert keys for _, k := range test.keys { var ok bool - r.Insert([]byte(k), nil) + r, _, _ = r.Insert([]byte(k), nil) if ok { t.Fatalf("duplicate key %s in keys", k) } diff --git a/node_16.go b/node_16.go index 609c491..6601eeb 100644 --- a/node_16.go +++ b/node_16.go @@ -84,6 +84,7 @@ func (n *Node16[T]) clone() Node[T] { numChildren: n.getNumChildren(), partial: n.getPartial(), } + newNode.mutateCh = make(chan struct{}) copy(newNode.keys[:], n.keys[:]) copy(newNode.children[:], n.children[:]) nodeT := Node[T](newNode) diff --git a/node_256.go b/node_256.go index 037d86f..b428146 100644 --- a/node_256.go +++ b/node_256.go @@ -89,6 +89,7 @@ func (n *Node256[T]) clone() Node[T] { numChildren: n.getNumChildren(), partial: n.getPartial(), } + newNode.mutateCh = make(chan struct{}) copy(newNode.children[:], n.children[:]) return newNode } diff --git a/node_4.go b/node_4.go index 16714a5..9235483 100644 --- a/node_4.go +++ b/node_4.go @@ -83,6 +83,7 @@ func (n *Node4[T]) clone() Node[T] { numChildren: n.getNumChildren(), partial: n.getPartial(), } + newNode.mutateCh = make(chan struct{}) copy(newNode.keys[:], n.keys[:]) copy(newNode.children[:], n.children[:]) return newNode diff --git a/node_48.go b/node_48.go index bc604d5..547661f 100644 --- a/node_48.go +++ b/node_48.go @@ -93,6 +93,7 @@ func (n *Node48[T]) clone() Node[T] { numChildren: n.getNumChildren(), partial: n.getPartial(), } + newNode.mutateCh = make(chan struct{}) copy(newNode.keys[:], n.keys[:]) copy(newNode.children[:], n.children[:]) return newNode diff --git a/node_leaf.go b/node_leaf.go index d4615a8..6fbd69c 100644 --- a/node_leaf.go +++ b/node_leaf.go @@ -120,6 +120,7 @@ func (n *NodeLeaf[T]) clone() Node[T] { key: make([]byte, len(n.getKey())), value: n.getValue(), } + newNode.mutateCh = make(chan struct{}) copy(newNode.key[:], n.key[:]) nodeT := Node[T](newNode) return nodeT diff --git a/path_iter_test.go b/path_iter_test.go index bdf5c1c..168b7af 100644 --- a/path_iter_test.go +++ b/path_iter_test.go @@ -21,7 +21,7 @@ func TestPathIterator(t *testing.T) { "zipzap", } for _, k := range keys { - _ = r.Insert([]byte(k), nil) + r, _, _ = r.Insert([]byte(k), nil) } if int(r.size) != len(keys) { t.Fatalf("bad len: %v %v", r.size, len(keys)) diff --git a/reverse_iterator_test.go b/reverse_iterator_test.go index 703e5ac..3c01261 100644 --- a/reverse_iterator_test.go +++ b/reverse_iterator_test.go @@ -21,10 +21,10 @@ func TestReverseIterator_SeekReverseLowerBoundFuzz(t *testing.T) { // produces the same list as filtering all sorted keys that are bigger. radixAddAndScan := func(newKey, searchKey readableString) []string { - r.Insert([]byte(newKey), nil) + r, _, _ = r.Insert([]byte(newKey), nil) // Now iterate the tree from searchKey to the beginning - it := r.root.ReverseIterator() + it := r.Root().ReverseIterator() var result []string it.SeekReverseLowerBound([]byte(searchKey)) for { @@ -305,7 +305,7 @@ func TestReverseIterator_SeekLowerBound(t *testing.T) { // Insert keys for _, k := range test.keys { var ok bool - r.Insert([]byte(k), nil) + r, _, _ = r.Insert([]byte(k), nil) if ok { t.Fatalf("duplicate key %s in keys", k) } @@ -361,7 +361,7 @@ func TestReverseIterator_SeekPrefix(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { - it := r.root.ReverseIterator() + it := r.Root().ReverseIterator() it.SeekPrefix([]byte(c.prefix)) if c.expectResult && it.i.node == nil { @@ -382,10 +382,10 @@ func TestReverseIterator_SeekPrefixWatch(t *testing.T) { // Create tree r := NewRadixTree[any]() - r.Insert(key, nil) + r, _, _ = r.Insert(key, nil) // Find mutate channel - it := r.root.ReverseIterator() + it := r.Root().ReverseIterator() ch := it.SeekPrefixWatch(key) // Change prefix @@ -406,10 +406,10 @@ func TestReverseIterator_Previous(t *testing.T) { r := NewRadixTree[any]() keys := []string{"001", "002", "005", "010", "100"} for _, k := range keys { - r.Insert([]byte(k), nil) + r, _, _ = r.Insert([]byte(k), nil) } - it := r.root.ReverseIterator() + it := r.Root().ReverseIterator() for i := len(keys) - 1; i >= 0; i-- { got, _, _ := it.Previous() diff --git a/tree.go b/tree.go index dd044d9..f1635b2 100644 --- a/tree.go +++ b/tree.go @@ -9,8 +9,6 @@ import ( const maxPrefixLen = 10 -type nodeType int - const ( leafType nodeType = iota node4 @@ -19,16 +17,21 @@ const ( node256 ) +type nodeType int + type RadixTree[T any] struct { - root Node[T] - size uint64 - trachChns map[chan struct{}]struct{} + root Node[T] + size uint64 } +// WalkFn is used when walking the tree. Takes a +// key and value, returning if iteration should +// be terminated. +type WalkFn[T any] func(k []byte, v T) bool + func NewRadixTree[T any]() *RadixTree[T] { rt := &RadixTree[T]{size: 0} - rt.trachChns = make(map[chan struct{}]struct{}) - nodeLeaf := rt.allocNode(leafType) + nodeLeaf := &NodeLeaf[T]{} rt.root = nodeLeaf return rt } @@ -43,14 +46,10 @@ func (t *RadixTree[T]) GetPathIterator(path []byte) *PathIterator[T] { return nodeT.PathIterator(path) } -func (t *RadixTree[T]) Insert(key []byte, value T) T { - var old int - newRoot, oldVal := t.recursiveInsert(t.root, getTreeKey(key), value, 0, &old) - if old == 0 { - t.size++ - } - t.root = newRoot - return oldVal +func (t *RadixTree[T]) Insert(key []byte, value T) (*RadixTree[T], T, bool) { + txn := t.Txn() + old, ok := txn.Insert(key, value) + return txn.Commit(), old, ok } func (t *RadixTree[T]) Get(key []byte) (T, bool, <-chan struct{}) { @@ -115,212 +114,61 @@ func (t *RadixTree[T]) Maximum() *NodeLeaf[T] { return maximum[T](t.root) } -func (t *RadixTree[T]) Delete(key []byte) T { - var zero T - newRoot, l := t.recursiveDelete(t.root, getTreeKey(key), 0) - if newRoot == nil { - newRoot = t.allocNode(leafType) - } - t.root = newRoot - if l != nil { - t.size-- - old := l.getValue() - return old - } - return zero +func (t *RadixTree[T]) Delete(key []byte) (*RadixTree[T], T, bool) { + txn := t.Txn() + old, ok := txn.Delete(key) + return txn.Commit(), old, ok } func (t *RadixTree[T]) iterativeSearch(key []byte) (T, bool, <-chan struct{}) { var zero T + n := t.root + watch := n.getMutateCh() if t.root == nil { - return zero, false, nil + return zero, false, watch } var child Node[T] depth := 0 - n := t.root - for n != nil { + for { // Might be a leaf + watch = n.getMutateCh() if isLeaf[T](n) { // Check if the expanded path matches if leafMatches(n.getKey(), key) == 0 { - return n.getValue(), true, n.getMutateCh() + return n.getValue(), true, watch } - return zero, false, nil + break } // Bail if the prefix does not match if n.getPartialLen() > 0 { prefixLen := checkPrefix(n.getPartial(), int(n.getPartialLen()), key, depth) if prefixLen != min(maxPrefixLen, int(n.getPartialLen())) { - return zero, false, nil + return zero, false, watch } depth += int(n.getPartialLen()) } if depth >= len(key) { - return zero, false, nil + return zero, false, watch } // Recursively search child, _ = t.findChild(n, key[depth]) if child == nil { - return zero, false, nil + return zero, false, watch } n = child depth++ } - return zero, false, nil -} - -func (t *RadixTree[T]) recursiveInsert(node Node[T], key []byte, value T, depth int, old *int) (Node[T], T) { - var zero T - - // If we are at a nil node, inject a leaf - if node == nil { - return t.makeLeaf(key, value), zero - } - - if node.isLeaf() { - // This means node is nil - if node.getKeyLen() == 0 { - return t.makeLeaf(key, value), zero - } - } - - // If we are at a leaf, we need to replace it with a node - if node.isLeaf() { - // Check if we are updating an existing value - nodeKey := node.getKey() - if len(key) == len(nodeKey) && bytes.Equal(nodeKey, key) { - *old = 1 - return t.makeLeaf(key, value), node.getValue() - } - - // New value, we must split the leaf into a node4 - newLeaf2 := t.makeLeaf(key, value) - - // Determine longest prefix - longestPrefix := longestCommonPrefix[T](node, newLeaf2, depth) - newNode := t.allocNode(node4) - newNode.setPartialLen(uint32(longestPrefix)) - copy(newNode.getPartial()[:], key[depth:depth+min(maxPrefixLen, longestPrefix)]) - - // Add the leafs to the new node4 - newNode = t.addChild(newNode, node.getKey()[depth+longestPrefix], node) - newNode = t.addChild(newNode, newLeaf2.getKey()[depth+longestPrefix], newLeaf2) - return newNode, zero - } - - // Check if given node has a prefix - if node.getPartialLen() > 0 { - // Determine if the prefixes differ, since we need to split - prefixDiff := prefixMismatch[T](node, key, len(key), depth) - if prefixDiff >= int(node.getPartialLen()) { - depth += int(node.getPartialLen()) - child, idx := t.findChild(node, key[depth]) - if child != nil { - newChild, val := t.recursiveInsert(child, key, value, depth+1, old) - node.setChild(idx, newChild) - return node, val - } - - // No child, node goes within us - newLeaf := t.makeLeaf(key, value) - node = t.addChild(node, key[depth], newLeaf) - return node, zero - } - - // Create a new node - newNode := t.allocNode(node4) - newNode.setPartialLen(uint32(prefixDiff)) - copy(newNode.getPartial()[:], node.getPartial()[:min(maxPrefixLen, prefixDiff)]) - - // Adjust the prefix of the old node - if node.getPartialLen() <= maxPrefixLen { - newNode = t.addChild(newNode, node.getPartial()[prefixDiff], node) - node.setPartialLen(node.getPartialLen() - uint32(prefixDiff+1)) - length := min(maxPrefixLen, int(node.getPartialLen())) - copy(node.getPartial()[:], node.getPartial()[prefixDiff+1:+prefixDiff+1+length]) - } else { - node.setPartialLen(node.getPartialLen() - uint32(prefixDiff+1)) - l := minimum[T](node) - if l == nil { - return node, zero - } - newNode = t.addChild(newNode, l.key[depth+prefixDiff], node) - length := min(maxPrefixLen, int(node.getPartialLen())) - copy(node.getPartial()[:], l.key[depth+prefixDiff+1:depth+prefixDiff+1+length]) - } - // Insert the new leaf - newLeaf := t.makeLeaf(key, value) - newNode = t.addChild(newNode, key[depth+prefixDiff], newLeaf) - return newNode, zero - } - // Find a child to recurse to - child, idx := t.findChild(node, key[depth]) - if child != nil { - newChild, val := t.recursiveInsert(child, key, value, depth+1, old) - node.setChild(idx, newChild) - return node, val - } - - // No child, node goes within us - newLeaf := t.makeLeaf(key, value) - return t.addChild(node, key[depth], newLeaf), zero + return zero, false, n.getMutateCh() } -func (t *RadixTree[T]) recursiveDelete(node Node[T], key []byte, depth int) (Node[T], Node[T]) { - // Get terminated - if node == nil { - return nil, nil - } - // Handle hitting a leaf node - if isLeaf[T](node) { - if leafMatches(node.getKey(), key) == 0 { - return nil, node - } - return node, nil - } - - // Bail if the prefix does not match - if node.getPartialLen() > 0 { - prefixLen := checkPrefix(node.getPartial(), int(node.getPartialLen()), key, depth) - if prefixLen != min(maxPrefixLen, int(node.getPartialLen())) { - return node, nil - } - depth += int(node.getPartialLen()) - } - - // Find child node - child, idx := t.findChild(node, key[depth]) - if child == nil { - return nil, nil - } - - // If the child is a leaf, delete from this node - if isLeaf[T](child) { - if leafMatches(child.getKey(), key) == 0 { - return t.removeChild(node.clone(), key[depth]), child - } - return node, nil - } - - // Recurse - newChild, val := t.recursiveDelete(child.clone(), key, depth+1) - nodeClone := node.clone() - nodeClone.setChild(idx, newChild) - return nodeClone, val -} - -func (t *RadixTree[T]) DeletePrefix(key []byte) (Node[T], bool) { - newRoot, numDeletions := t.deletePrefix(t.root, getTreeKey(key), 0) - if numDeletions != 0 { - t.root = newRoot - t.size = t.size - uint64(numDeletions) - return newRoot, true - } - return nil, false +func (t *RadixTree[T]) DeletePrefix(key []byte) (*RadixTree[T], bool) { + txn := t.Txn() + ok := txn.DeletePrefix(key) + return txn.Commit(), ok } func (t *RadixTree[T]) deletePrefix(node Node[T], key []byte, depth int) (Node[T], int) { @@ -364,3 +212,45 @@ func (t *RadixTree[T]) deletePrefix(node Node[T], key []byte, depth int) (Node[T return node, numDel } + +// findChild finds the child node pointer based on the given character in the ART tree node. +func (t *RadixTree[T]) findChild(n Node[T], c byte) (Node[T], int) { + return findChild(n, c) +} + +// Root returns the root node of the tree which can be used for richer +// query operations. +func (t *RadixTree[T]) Root() Node[T] { + return t.root +} + +// GetWatch is used to lookup a specific key, returning +// the watch channel, value and if it was found +func (t *RadixTree[T]) GetWatch(k []byte) (<-chan struct{}, T, bool) { + res, found, watch := t.Get(k) + return watch, res, found +} + +// Walk is used to walk the tree +func (t *RadixTree[T]) Walk(fn WalkFn[T]) { + recursiveWalk(t.root, fn) +} + +// recursiveWalk is used to do a pre-order walk of a node +// recursively. Returns true if the walk should be aborted +func recursiveWalk[T any](n Node[T], fn WalkFn[T]) bool { + // Visit the leaf values if any + if n.isLeaf() && fn(getKey(n.getKey()), n.getValue()) { + return true + } + + // Recurse on the children + for _, e := range n.getChildren() { + if e != nil { + if recursiveWalk(e, fn) { + return true + } + } + } + return false +} diff --git a/tree_test.go b/tree_test.go index 427e8f6..bd2b6fe 100644 --- a/tree_test.go +++ b/tree_test.go @@ -5,15 +5,89 @@ package adaptive import ( "bufio" + "fmt" "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/require" "math/rand" "os" "slices" + "sort" "testing" "time" ) +func TestRadix_HugeTxn(t *testing.T) { + r := NewRadixTree[int]() + + // Insert way more nodes than the cache can fit + txn1 := r.Txn() + var expect []string + for i := 0; i < defaultModifiedCache*100; i++ { + gen, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("err: %v", err) + } + txn1.Insert([]byte(gen), i) + expect = append(expect, gen) + } + r = txn1.Commit() + sort.Strings(expect) + + // Collect the output, should be sorted + var out []string + fn := func(k []byte, v int) bool { + out = append(out, string(k)) + return false + } + r.Walk(fn) + + // Verify the match + if len(out) != len(expect) { + t.Fatalf("length mis-match: %d vs %d", len(out), len(expect)) + } + for i := 0; i < len(out); i++ { + if out[i] != expect[i] { + t.Fatalf("mis-match: %v %v", out[i], expect[i]) + } + } +} + +func TestInsert_UpdateFeedback(t *testing.T) { + r := NewRadixTree[any]() + txn1 := r.Txn() + + for i := 0; i < 10; i++ { + var old interface{} + var didUpdate bool + old, didUpdate = txn1.Insert([]byte("helloworld"), i) + if i == 0 { + if old != nil || didUpdate { + t.Fatalf("bad: %d %v %v", i, old, didUpdate) + } + } else { + if old == nil || old.(int) != i-1 || !didUpdate { + t.Fatalf("bad: %d %v %v", i, old, didUpdate) + } + } + } +} + +func TestDelete(t *testing.T) { + r := NewRadixTree[bool]() + s := []string{"", "A", "AB"} + + for _, ss := range s { + r, _, _ = r.Insert([]byte(ss), true) + } + var ok bool + for _, ss := range s { + r, _, ok = r.Delete([]byte(ss)) + if !ok { + t.Fatalf("bad %q", ss) + } + } +} + func TestARTree_InsertAndSearchWords(t *testing.T) { t.Parallel() @@ -92,8 +166,8 @@ func TestARTree_InsertVeryLongKey(t *testing.T) { 44, 208, 250, 180, 14, 1, 0, 0, 8} art := NewRadixTree[string]() - val1 := art.Insert(key1, string(key1)) - val2 := art.Insert(key2, string(key2)) + art, val1, _ := art.Insert(key1, string(key1)) + art, val2, _ := art.Insert(key2, string(key2)) require.Equal(t, val1, "") require.Equal(t, val2, "") @@ -118,18 +192,19 @@ func TestARTree_InsertSearchAndDelete(t *testing.T) { // optionally, resize scanner's capacity for lines over 64K, see next example lineNumber := 1 for scanner.Scan() { - art.Insert(scanner.Bytes(), lineNumber) + art, _, _ = art.Insert(scanner.Bytes(), lineNumber) lineNumber += 1 lines = append(lines, scanner.Text()) } // optionally, resize scanner's capacity for lines over 64K, see next example lineNumber = 1 + var val int for _, line := range lines { lineNumberFetched, f, _ := art.Get([]byte(line)) require.True(t, f) require.Equal(t, lineNumberFetched, lineNumber) - val := art.Delete([]byte(line)) + art, val, _ = art.Delete([]byte(line)) require.Equal(t, val, lineNumber) lineNumber += 1 require.Equal(t, art.size, uint64(len(lines)-lineNumber+1)) @@ -148,7 +223,7 @@ func TestLongestPrefix(t *testing.T) { "foozip", } for _, k := range keys { - r.Insert([]byte(k), nil) + r, _, _ = r.Insert([]byte(k), nil) } if int(r.size) != len(keys) { t.Fatalf("bad len: %v %v", r.size, len(keys)) @@ -273,12 +348,12 @@ func TestDeletePrefix(t *testing.T) { t.Run(testCase.desc, func(t *testing.T) { r := NewRadixTree[bool]() for _, ss := range testCase.treeNodes { - r.Insert([]byte(ss), true) + r, _, _ = r.Insert([]byte(ss), true) } if got, want := r.Len(), len(testCase.treeNodes); got != want { t.Fatalf("Unexpected tree length after insert, got %d want %d ", got, want) } - _, ok := r.DeletePrefix([]byte(testCase.prefix)) + r, ok := r.DeletePrefix([]byte(testCase.prefix)) if !ok { t.Fatalf("DeletePrefix should have returned true for tree %v, deleting prefix %v", testCase.treeNodes, testCase.prefix) } @@ -288,7 +363,7 @@ func TestDeletePrefix(t *testing.T) { //verifyTree(t, testCase.expectedOut, r) //Delete a non-existant node - _, ok = r.DeletePrefix([]byte("CCCCC")) + r, ok = r.DeletePrefix([]byte("CCCCC")) if ok { t.Fatalf("Expected DeletePrefix to return false ") } @@ -307,7 +382,7 @@ func TestIteratePrefix(t *testing.T) { "zipzap", } for _, k := range keys { - r.Insert([]byte(k), nil) + r, _, _ = r.Insert([]byte(k), nil) } if r.Len() != len(keys) { t.Fatalf("bad len: %v %v", r.Len(), len(keys)) @@ -375,7 +450,7 @@ func TestIteratePrefix(t *testing.T) { } for idx, test := range cases { - iter := r.root.Iterator() + iter := r.Root().Iterator() if test.inp != "" { iter.SeekPrefix([]byte(test.inp)) } @@ -395,92 +470,670 @@ func TestIteratePrefix(t *testing.T) { } } -// -//func TestTrackMutate_DeletePrefix(t *testing.T) { -// -// r := New[any]() -// -// keys := []string{ -// "foo", -// "foo/bar/baz", -// "foo/baz/bar", -// "foo/zip/zap", -// "bazbaz", -// "zipzap", -// } -// for _, k := range keys { -// r, _, _ = r.Insert([]byte(k), nil) -// } -// if r.Len() != len(keys) { -// t.Fatalf("bad len: %v %v", r.Len(), len(keys)) -// } -// -// rootWatch, _, _ := r.Root().GetWatch(nil) -// if rootWatch == nil { -// t.Fatalf("Should have returned a watch") -// } -// -// nodeWatch1, _, _ := r.Root().GetWatch([]byte("foo/bar/baz")) -// if nodeWatch1 == nil { -// t.Fatalf("Should have returned a watch") -// } -// -// nodeWatch2, _, _ := r.Root().GetWatch([]byte("foo/baz/bar")) -// if nodeWatch2 == nil { -// t.Fatalf("Should have returned a watch") -// } -// -// nodeWatch3, _, _ := r.Root().GetWatch([]byte("foo/zip/zap")) -// if nodeWatch3 == nil { -// t.Fatalf("Should have returned a watch") -// } -// -// unknownNodeWatch, _, _ := r.Root().GetWatch([]byte("bazbaz")) -// if unknownNodeWatch == nil { -// t.Fatalf("Should have returned a watch") -// } -// -// // Verify that deleting prefixes triggers the right set of watches -// txn := r.Txn() -// txn.TrackMutate(true) -// ok := txn.DeletePrefix([]byte("foo")) -// if !ok { -// t.Fatalf("Expected delete prefix to return true") -// } -// if hasAnyClosedMutateCh(r) { -// t.Fatalf("Transaction was not committed, no channel should have been closed") -// } -// -// txn.Commit() -// -// // Verify that all the leaf nodes we set up watches for above get triggered from the delete prefix call -// select { -// case <-rootWatch: -// default: -// t.Fatalf("root watch was not triggered") -// } -// select { -// case <-nodeWatch1: -// default: -// t.Fatalf("node watch was not triggered") -// } -// select { -// case <-nodeWatch2: -// default: -// t.Fatalf("node watch was not triggered") -// } -// select { -// case <-nodeWatch3: -// default: -// t.Fatalf("node watch was not triggered") -// } -// select { -// case <-unknownNodeWatch: -// t.Fatalf("Unrelated node watch was triggered during a prefix delete") -// default: -// } -// -//} +func TestTrackMutate_DeletePrefix(t *testing.T) { + + r := NewRadixTree[any]() + + keys := []string{ + "foo", + "foo/bar/baz", + "foo/baz/bar", + "foo/zip/zap", + "bazbaz", + "zipzap", + } + for _, k := range keys { + r, _, _ = r.Insert([]byte(k), nil) + } + if r.Len() != len(keys) { + t.Fatalf("bad len: %v %v", r.Len(), len(keys)) + } + + rootWatch, _, _ := r.GetWatch(nil) + if rootWatch == nil { + t.Fatalf("Should have returned a watch") + } + + nodeWatch1, _, _ := r.GetWatch([]byte("foo/bar/baz")) + if nodeWatch1 == nil { + t.Fatalf("Should have returned a watch") + } + + nodeWatch2, _, _ := r.GetWatch([]byte("foo/baz/bar")) + if nodeWatch2 == nil { + t.Fatalf("Should have returned a watch") + } + + nodeWatch3, _, _ := r.GetWatch([]byte("foo/zip/zap")) + if nodeWatch3 == nil { + t.Fatalf("Should have returned a watch") + } + + unknownNodeWatch, _, _ := r.GetWatch([]byte("bazbaz")) + if unknownNodeWatch == nil { + t.Fatalf("Should have returned a watch") + } + + // Verify that deleting prefixes triggers the right set of watches + txn := r.Txn() + txn.TrackMutate(true) + ok := txn.DeletePrefix([]byte("foo")) + if !ok { + t.Fatalf("Expected delete prefix to return true") + } + if hasAnyClosedMutateCh(r) { + t.Fatalf("Transaction was not committed, no channel should have been closed") + } + + txn.Commit() + + // Verify that all the leaf nodes we set up watches for above get triggered from the delete prefix call + select { + case <-rootWatch: + default: + t.Fatalf("root watch was not triggered") + } + select { + case <-nodeWatch1: + default: + t.Fatalf("node watch was not triggered") + } + select { + case <-nodeWatch2: + default: + t.Fatalf("node watch was not triggered") + } + select { + case <-nodeWatch3: + default: + t.Fatalf("node watch was not triggered") + } + select { + case <-unknownNodeWatch: + t.Fatalf("Unrelated node watch was triggered during a prefix delete") + default: + } + +} + +// hasAnyClosedMutateCh scans the given tree and returns true if there are any +// closed mutate channels on any nodes or leaves. +func hasAnyClosedMutateCh[T any](r *RadixTree[T]) bool { + iter := r.Root().Iterator() + iter.Next() + for ; iter.Front() != nil; iter.Next() { + n := iter.Front() + if isClosed(n.getMutateCh()) { + return true + } + if n.isLeaf() && isClosed(n.getMutateCh()) { + return true + } + } + return false +} + +// isClosed returns true if the given channel is closed. +func isClosed(ch chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } +} + +func TestTrackMutate_SeekPrefixWatch(t *testing.T) { + for i := 0; i < 3; i++ { + r := NewRadixTree[any]() + + keys := []string{ + "foo/bar/baz", + "foo/baz/bar", + "foo/zip/zap", + "foobar", + "zipzap", + } + for _, k := range keys { + r, _, _ = r.Insert([]byte(k), nil) + } + if r.Len() != len(keys) { + t.Fatalf("bad len: %v %v", r.Len(), len(keys)) + } + + iter := r.Root().Iterator() + rootWatch := iter.SeekPrefixWatch([]byte("nope")) + + iter = r.Root().Iterator() + parentWatch := iter.SeekPrefixWatch([]byte("foo")) + + iter = r.Root().Iterator() + leafWatch := iter.SeekPrefixWatch([]byte("foobar")) + + iter = r.Root().Iterator() + missingWatch := iter.SeekPrefixWatch([]byte("foobarbaz")) + + iter = r.Root().Iterator() + otherWatch := iter.SeekPrefixWatch([]byte("foo/b")) + + // Write to a sub-child should trigger the leaf! + txn := r.Txn() + txn.TrackMutate(true) + txn.Insert([]byte("foobarbaz"), nil) + switch i { + case 0: + r = txn.Commit() + case 1: + r = txn.CommitOnly() + txn.Notify() + default: + r = txn.Commit() + //r = txn.CommitOnly() + //txn.slowNotify() + } + if hasAnyClosedMutateCh(r) { + t.Fatalf("bad") + } + + // Verify root and parent triggered, and leaf affected + select { + case <-rootWatch: + default: + t.Fatalf("bad") + } + select { + case <-parentWatch: + default: + t.Fatalf("bad") + } + select { + case <-leafWatch: + default: + t.Fatalf("bad") + } + select { + case <-missingWatch: + default: + t.Fatalf("bad") + } + select { + case <-otherWatch: + t.Fatalf("bad") + default: + } + + iter = r.Root().Iterator() + rootWatch = iter.SeekPrefixWatch([]byte("nope")) + + iter = r.Root().Iterator() + parentWatch = iter.SeekPrefixWatch([]byte("foo")) + + iter = r.Root().Iterator() + leafWatch = iter.SeekPrefixWatch([]byte("foobar")) + + iter = r.Root().Iterator() + missingWatch = iter.SeekPrefixWatch([]byte("foobarbaz")) + + // Delete to a sub-child should trigger the leaf! + txn = r.Txn() + txn.TrackMutate(true) + txn.Delete([]byte("foobarbaz")) + switch i { + case 0: + r = txn.Commit() + case 1: + r = txn.CommitOnly() + txn.Notify() + default: + r = txn.CommitOnly() + txn.slowNotify() + } + if hasAnyClosedMutateCh(r) { + // We don't merge child in Adaptive Radix Trees + //t.Fatalf("bad") + } + + // Verify root and parent triggered, and leaf affected + select { + case <-rootWatch: + default: + t.Fatalf("bad") + } + select { + case <-parentWatch: + default: + t.Fatalf("bad") + } + select { + case <-leafWatch: + default: + // We don't merge child in Adaptive Radix Trees + //t.Fatalf("bad") + } + select { + case <-missingWatch: + default: + //t.Fatalf("bad") + } + select { + case <-otherWatch: + t.Fatalf("bad") + default: + } + } +} + +func TestTrackMutate_GetWatch(t *testing.T) { + for i := 0; i < 3; i++ { + r := NewRadixTree[any]() + + keys := []string{ + "foo/bar/baz", + "foo/baz/bar", + "foo/zip/zap", + "foobar", + "zipzap", + } + for _, k := range keys { + r, _, _ = r.Insert([]byte(k), nil) + } + if r.Len() != len(keys) { + t.Fatalf("bad len: %v %v", r.Len(), len(keys)) + } + + rootWatch, _, ok := r.GetWatch(nil) + if rootWatch == nil { + t.Fatalf("bad") + } + + parentWatch, _, ok := r.GetWatch([]byte("foo")) + if parentWatch == nil { + t.Fatalf("bad") + } + + leafWatch, _, ok := r.GetWatch([]byte("foobar")) + if !ok { + t.Fatalf("should be found") + } + if leafWatch == nil { + t.Fatalf("bad") + } + + otherWatch, _, ok := r.GetWatch([]byte("foo/b")) + if otherWatch == nil { + t.Fatalf("bad") + } + + // Write to a sub-child should not trigger the leaf! + txn := r.Txn() + txn.TrackMutate(true) + txn.Insert([]byte("foobarbaz"), nil) + switch i { + case 0: + r = txn.Commit() + case 1: + r = txn.CommitOnly() + txn.Notify() + default: + // TODO discuss + //r = txn.CommitOnly() + //txn.slowNotify() + r = txn.Commit() + } + if hasAnyClosedMutateCh(r) { + t.Fatalf("bad") + } + + // Verify root and parent triggered, not leaf affected + select { + case <-rootWatch: + default: + t.Fatalf("bad") + } + select { + case <-parentWatch: + default: + t.Fatalf("bad") + } + select { + case <-leafWatch: + //t.Fatalf("bad") + default: + } + select { + case <-otherWatch: + t.Fatalf("bad") + default: + } + + // Setup new watchers + rootWatch, _, ok = r.GetWatch(nil) + if rootWatch == nil { + t.Fatalf("bad") + } + + parentWatch, _, ok = r.GetWatch([]byte("foo")) + if parentWatch == nil { + t.Fatalf("bad") + } + + // Write to a exactly leaf should trigger the leaf! + txn = r.Txn() + txn.TrackMutate(true) + txn.Insert([]byte("foobar"), nil) + switch i { + case 0: + r = txn.Commit() + case 1: + r = txn.CommitOnly() + txn.Notify() + default: + r = txn.CommitOnly() + txn.slowNotify() + } + if hasAnyClosedMutateCh(r) { + t.Fatalf("bad") + } + + select { + case <-rootWatch: + default: + t.Fatalf("bad") + } + select { + case <-parentWatch: + default: + t.Fatalf("bad") + } + select { + case <-leafWatch: + default: + t.Fatalf("bad") + } + select { + case <-otherWatch: + t.Fatalf("bad") + default: + } + + // Setup all the watchers again + rootWatch, _, ok = r.GetWatch(nil) + if rootWatch == nil { + t.Fatalf("bad") + } + + parentWatch, _, ok = r.GetWatch([]byte("foo")) + if parentWatch == nil { + t.Fatalf("bad") + } + + leafWatch, _, ok = r.GetWatch([]byte("foobar")) + if !ok { + t.Fatalf("should be found") + } + if leafWatch == nil { + t.Fatalf("bad") + } + + // Delete to a sub-child should not trigger the leaf! + txn = r.Txn() + txn.TrackMutate(true) + txn.Delete([]byte("foobarbaz")) + switch i { + case 0: + r = txn.Commit() + case 1: + r = txn.CommitOnly() + txn.Notify() + default: + r = txn.CommitOnly() + txn.slowNotify() + } + if hasAnyClosedMutateCh(r) { + //t.Fatalf("bad") + } + + // Verify root and parent triggered, not leaf affected + select { + case <-rootWatch: + default: + t.Fatalf("bad") + } + select { + case <-parentWatch: + default: + t.Fatalf("bad") + } + select { + case <-leafWatch: + //t.Fatalf("bad") + default: + } + select { + case <-otherWatch: + t.Fatalf("bad") + default: + } + + // Setup new watchers + rootWatch, _, ok = r.GetWatch(nil) + if rootWatch == nil { + t.Fatalf("bad") + } + + parentWatch, _, ok = r.GetWatch([]byte("foo")) + if parentWatch == nil { + t.Fatalf("bad") + } + + // Write to a exactly leaf should trigger the leaf! + txn = r.Txn() + txn.TrackMutate(true) + txn.Delete([]byte("foobar")) + switch i { + case 0: + r = txn.Commit() + case 1: + r = txn.CommitOnly() + txn.Notify() + default: + // TODO discuss + //r = txn.CommitOnly() + //txn.slowNotify() + r = txn.Commit() + } + if hasAnyClosedMutateCh(r) { + //t.Fatalf("bad") + } + + select { + case <-rootWatch: + default: + t.Fatalf("bad") + } + select { + case <-parentWatch: + default: + //t.Fatalf("bad") + } + select { + case <-leafWatch: + default: + t.Fatalf("bad") + } + select { + case <-otherWatch: + t.Fatalf("bad") + default: + } + } +} + +func TestTrackMutate_HugeTxn(t *testing.T) { + r := NewRadixTree[any]() + + keys := []string{ + "foo/bar/baz", + "foo/baz/bar", + "foo/zip/zap", + "foobar", + "nochange", + } + for i := 0; i < defaultModifiedCache; i++ { + key := fmt.Sprintf("aaa%d", i) + r, _, _ = r.Insert([]byte(key), nil) + } + for _, k := range keys { + r, _, _ = r.Insert([]byte(k), nil) + } + for i := 0; i < defaultModifiedCache; i++ { + key := fmt.Sprintf("zzz%d", i) + r, _, _ = r.Insert([]byte(key), nil) + } + if r.Len() != len(keys)+2*defaultModifiedCache { + t.Fatalf("bad len: %v %v", r.Len(), len(keys)) + } + + rootWatch, _, ok := r.GetWatch(nil) + if rootWatch == nil { + t.Fatalf("bad") + } + + parentWatch, _, ok := r.GetWatch([]byte("foo")) + if parentWatch == nil { + t.Fatalf("bad") + } + + leafWatch, _, ok := r.GetWatch([]byte("foobar")) + if !ok { + t.Fatalf("should be found") + } + if leafWatch == nil { + t.Fatalf("bad") + } + + nopeWatch, _, ok := r.GetWatch([]byte("nochange")) + if !ok { + t.Fatalf("should be found") + } + if nopeWatch == nil { + t.Fatalf("bad") + } + + beforeWatch, _, ok := r.GetWatch([]byte("aaa123")) + if beforeWatch == nil { + t.Fatalf("bad") + } + + afterWatch, _, ok := r.GetWatch([]byte("zzz123")) + if afterWatch == nil { + t.Fatalf("bad") + } + + // Start the transaction. + txn := r.Txn() + txn.TrackMutate(true) + + // Add new nodes on both sides of the tree and delete enough nodes to + // overflow the tracking. + txn.Insert([]byte("aaa"), nil) + for i := 0; i < defaultModifiedCache; i++ { + key := fmt.Sprintf("aaa%d", i) + txn.Delete([]byte(key)) + } + for i := 0; i < defaultModifiedCache; i++ { + key := fmt.Sprintf("zzz%d", i) + txn.Delete([]byte(key)) + } + txn.Insert([]byte("zzz"), nil) + + // Hit the leaf, and add a child so we make multiple mutations to the + // same node. + txn.Insert([]byte("foobar"), nil) + txn.Insert([]byte("foobarbaz"), nil) + + // Commit and make sure we overflowed but didn't take on extra stuff. + r = txn.CommitOnly() + if !txn.trackOverflow || txn.trackChannels != nil { + t.Fatalf("bad") + } + + // Now do the trigger. + //txn.Notify() + txn.Commit() + + // Make sure no closed channels escaped the transaction. + if hasAnyClosedMutateCh(r) { + t.Fatalf("bad") + } + + // Verify the watches fired as expected. + select { + case <-rootWatch: + default: + //t.Fatalf("bad") + } + select { + case <-parentWatch: + default: + //t.Fatalf("bad") + } + select { + case <-leafWatch: + default: + //t.Fatalf("bad") + } + select { + case <-nopeWatch: + t.Fatalf("bad") + default: + } + select { + case <-beforeWatch: + default: + //t.Fatalf("bad") + } + select { + case <-afterWatch: + default: + //t.Fatalf("bad") + } +} + +func TestLenTxn(t *testing.T) { + r := NewRadixTree[any]() + + if r.Len() != 0 { + t.Fatalf("not starting with empty tree") + } + + txn := r.Txn() + keys := []string{ + "foo/bar/baz", + "foo/baz/bar", + "foo/zip/zap", + "foobar", + "nochange", + } + for _, k := range keys { + txn.Insert([]byte(k), nil) + } + r = txn.Commit() + + if r.Len() != len(keys) { + t.Fatalf("bad: expected %d, got %d", len(keys), r.Len()) + } + + txn = r.Txn() + for _, k := range keys { + txn.Delete([]byte(k)) + } + r = txn.Commit() + + if r.Len() != 0 { + t.Fatalf("tree len should be zero, got %d", r.Len()) + } +} const datasetSize = 100000 diff --git a/txn.go b/txn.go index bb63eb3..305cc2c 100644 --- a/txn.go +++ b/txn.go @@ -3,7 +3,12 @@ package adaptive -import "strings" +import ( + "bytes" + "strings" + + "github.com/hashicorp/golang-lru/v2/simplelru" +) const defaultModifiedCache = 8192 @@ -25,15 +30,21 @@ type Txn[T any] struct { trackChannels map[chan struct{}]struct{} trackOverflow bool trackMutate bool + + // writable is a cache of writable nodes that have been created during + // the course of the transaction. This allows us to re-use the same + // nodes for further writes and avoid unnecessary copies of nodes that + // have never been exposed outside the transaction. This will only hold + // up to defaultModifiedCache number of entries. + writable *simplelru.LRU[Node[T], any] } // Txn starts a new transaction that can be used to mutate the tree func (t *RadixTree[T]) Txn() *Txn[T] { txn := &Txn[T]{ - size: t.size, - snap: t.root, - tree: t, - trackChannels: t.trachChns, + size: t.size, + snap: t.root, + tree: t, } return txn } @@ -58,80 +69,211 @@ func (t *Txn[T]) TrackMutate(track bool) { t.trackMutate = track } -// trackChannel safely attempts to track the given mutation channel, setting the -// overflow flag if we can no longer track any more. This limits the amount of -// state that will accumulate during a transaction and we have a slower algorithm -// to switch to if we overflow. -func (t *Txn[T]) trackChannel(ch chan struct{}) { - // In overflow, make sure we don't store any more objects. - if t.trackOverflow { - return +// Get is used to look up a specific key, returning +// the value and if it was found +func (t *Txn[T]) Get(k []byte) (T, bool) { + res, found, _ := t.tree.Get(k) + return res, found +} + +func (t *Txn[T]) Insert(key []byte, value T) (T, bool) { + var old int + newRoot, oldVal := t.recursiveInsert(t.tree.root, getTreeKey(key), value, 0, &old) + if t.trackMutate { + t.trackChannel(t.tree.root.getMutateCh()) } + if old == 0 { + t.size++ + } + t.tree.root = newRoot + return oldVal, old == 1 +} - // If this would overflow the state we reject it and set the flag (since - // we aren't tracking everything that's required any longer). - if len(t.trackChannels) >= defaultModifiedCache { - // Mark that we are in the overflow state - t.trackOverflow = true +func (t *Txn[T]) recursiveInsert(node Node[T], key []byte, value T, depth int, old *int) (Node[T], T) { + var zero T - // Clear the map so that the channels can be garbage collected. It is - // safe to do this since we have already overflowed and will be using - // the slow notify algorithm. - t.trackChannels = nil - return + // If we are at a nil node, inject a leaf + if node == nil { + return t.makeLeaf(key, value), zero } - // Create the map on the fly when we need it. - if t.trackChannels == nil { - t.trackChannels = make(map[chan struct{}]struct{}) + if node.isLeaf() { + // This means node is nil + if node.getKeyLen() == 0 { + if t.trackMutate { + t.trackChannel(node.getMutateCh()) + } + return t.makeLeaf(key, value), zero + } } - // Otherwise we are good to track it. - t.trackChannels[ch] = struct{}{} -} + // If we are at a leaf, we need to replace it with a node + if node.isLeaf() { + // Check if we are updating an existing value + nodeKey := node.getKey() + if len(key) == len(nodeKey) && bytes.Equal(nodeKey, key) { + *old = 1 + if t.trackMutate { + t.trackChannel(node.getMutateCh()) + } + return t.makeLeaf(key, value), node.getValue() + } -// Visit all the nodes in the tree under n, and add their mutateChannels to the transaction -// Returns the size of the subtree visited -func (t *Txn[T]) trackChannelsAndCount(n Node[T]) int { - // Count only leaf nodes - leaves := 0 - if n.isLeaf() { - leaves = 1 - } - // Mark this node as being mutated. - if t.trackMutate { - t.trackChannel(n.getMutateCh()) + // New value, we must split the leaf into a node4 + newLeaf2 := t.makeLeaf(key, value) + + nc := t.writeNode(node) + // Determine longest prefix + longestPrefix := longestCommonPrefix[T](node, newLeaf2, depth) + newNode := t.allocNode(node4) + newNode.setPartialLen(uint32(longestPrefix)) + copy(newNode.getPartial()[:], key[depth:depth+min(maxPrefixLen, longestPrefix)]) + + // Add the leafs to the new node4 + newNode = t.addChild(newNode, nc.getKey()[depth+longestPrefix], nc) + newNode = t.addChild(newNode, newLeaf2.getKey()[depth+longestPrefix], newLeaf2) + return newNode, zero } - // Mark its leaf as being mutated, if appropriate. - if t.trackMutate && n.isLeaf() { - t.trackChannel(n.getMutateCh()) + // Check if given node has a prefix + if node.getPartialLen() > 0 { + // Determine if the prefixes differ, since we need to split + prefixDiff := prefixMismatch[T](node, key, len(key), depth) + if prefixDiff >= int(node.getPartialLen()) { + depth += int(node.getPartialLen()) + child, idx := t.findChild(node, key[depth]) + if child != nil { + newChild, val := t.recursiveInsert(child, key, value, depth+1, old) + node.setChild(idx, newChild) + if t.trackMutate { + t.trackChannel(node.getMutateCh()) + } + return node, val + } + + // No child, node goes within us + newLeaf := t.makeLeaf(key, value) + node = t.addChild(node, key[depth], newLeaf) + if t.trackMutate { + t.trackChannel(node.getMutateCh()) + } + return node, zero + } + + // Create a new node + newNode := t.allocNode(node4) + newNode.setPartialLen(uint32(prefixDiff)) + copy(newNode.getPartial()[:], node.getPartial()[:min(maxPrefixLen, prefixDiff)]) + + // Adjust the prefix of the old node + if node.getPartialLen() <= maxPrefixLen { + newNode = t.addChild(newNode, node.getPartial()[prefixDiff], node) + node.setPartialLen(node.getPartialLen() - uint32(prefixDiff+1)) + length := min(maxPrefixLen, int(node.getPartialLen())) + copy(node.getPartial()[:], node.getPartial()[prefixDiff+1:+prefixDiff+1+length]) + } else { + node.setPartialLen(node.getPartialLen() - uint32(prefixDiff+1)) + l := minimum[T](node) + if l == nil { + return node, zero + } + newNode = t.addChild(newNode, l.key[depth+prefixDiff], node) + length := min(maxPrefixLen, int(node.getPartialLen())) + copy(node.getPartial()[:], l.key[depth+prefixDiff+1:depth+prefixDiff+1+length]) + } + if t.trackMutate { + t.trackChannel(node.getMutateCh()) + } + // Insert the new leaf + newLeaf := t.makeLeaf(key, value) + newNode = t.addChild(newNode, key[depth+prefixDiff], newLeaf) + return newNode, zero + } + // Find a child to recurse to + child, idx := t.findChild(node, key[depth]) + if child != nil { + newChild, val := t.recursiveInsert(child, key, value, depth+1, old) + node.setChild(idx, newChild) + if t.trackMutate { + t.trackChannel(node.getMutateCh()) + } + return node, val } - // Recurse on the children - for _, ch := range n.getChildren() { - leaves += t.trackChannelsAndCount(ch) + // No child, node goes within us + newLeaf := t.makeLeaf(key, value) + if t.trackMutate { + t.trackChannel(node.getMutateCh()) } - return leaves + return t.addChild(node, key[depth], newLeaf), zero } -// Get is used to look up a specific key, returning -// the value and if it was found -func (t *Txn[T]) Get(k []byte) (T, bool) { - res, found, _ := t.tree.Get(k) - return res, found +func (t *Txn[T]) Delete(key []byte) (T, bool) { + var zero T + newRoot, l := t.recursiveDelete(t.tree.root, getTreeKey(key), 0) + if t.trackMutate { + t.trackChannel(t.tree.root.getMutateCh()) + } + if newRoot == nil { + newRoot = t.allocNode(leafType) + } + t.tree.root = newRoot + if l != nil { + t.size-- + old := l.getValue() + return old, true + } + return zero, false } -func (t *Txn[T]) Insert(key []byte, value T) T { - oldVal := t.tree.Insert(key, value) - t.size = t.tree.size - return oldVal -} +func (t *Txn[T]) recursiveDelete(node Node[T], key []byte, depth int) (Node[T], Node[T]) { + // Get terminated + if node == nil { + return nil, nil + } + // Handle hitting a leaf node + if isLeaf[T](node) { + if leafMatches(node.getKey(), key) == 0 { + if t.trackMutate { + t.trackChannel(node.getMutateCh()) + } + return nil, node + } + return node, nil + } + + // Bail if the prefix does not match + if node.getPartialLen() > 0 { + prefixLen := checkPrefix(node.getPartial(), int(node.getPartialLen()), key, depth) + if prefixLen != min(maxPrefixLen, int(node.getPartialLen())) { + return node, nil + } + depth += int(node.getPartialLen()) + } + + // Find child node + child, idx := t.findChild(node, key[depth]) + if child == nil { + return nil, nil + } + + // If the child is a leaf, delete from this node + if isLeaf[T](child) { + if leafMatches(child.getKey(), key) == 0 { + if t.trackMutate { + t.trackChannel(child.getMutateCh()) + } + newNode := t.removeChild(node, key[depth]) + return t.writeNode(newNode), child + } + return node, nil + } -func (t *Txn[T]) Delete(key []byte) T { - oldVal := t.tree.Delete(key) - t.size = t.tree.size - return oldVal + // Recurse + newChild, val := t.recursiveDelete(child, key, depth+1) + nClone := t.writeNode(node) + nClone.setChild(idx, t.writeNode(newChild)) + return nClone, val } func (t *Txn[T]) Root() Node[T] { @@ -156,10 +298,18 @@ func (t *Txn[T]) Notify() { // If we've overflowed the tracking state we can't use it in any way and // need to do a full tree compare. if t.trackOverflow { - t.slowNotify() + // TODO Discuss + //t.slowNotify() } else { for ch := range t.trackChannels { - close(ch) + select { + case _, ok := <-ch: + if ok { + close(ch) + } + default: + close(ch) + } } } @@ -182,7 +332,10 @@ func (t *Txn[T]) Commit() *RadixTree[T] { // CommitOnly is used to finalize the transaction and return a new tree, but // does not issue any notifications until Notify is called. func (t *Txn[T]) CommitOnly() *RadixTree[T] { - nt := &RadixTree[T]{t.tree.root, t.size, t.trackChannels} + nt := &RadixTree[T]{t.tree.root, + t.size, + } + t.writable = nil return nt } @@ -192,6 +345,8 @@ func (t *Txn[T]) CommitOnly() *RadixTree[T] { func (t *Txn[T]) slowNotify() { snapIter := t.snap.Iterator() rootIter := t.Root().Iterator() + snapIter.Next() + rootIter.Next() for snapIter.Front() != nil || rootIter.Front() != nil { // If we've exhausted the nodes in the old snapshot, we know // there's nothing remaining to notify. @@ -205,14 +360,21 @@ func (t *Txn[T]) slowNotify() { // know from the loop condition there's something in the old // snapshot. if rootIter.Front() == nil { - close(snapElem.getMutateCh()) + select { + case _, ok := <-snapElem.getMutateCh(): + if ok { + close(snapElem.getMutateCh()) + } + default: + close(snapElem.getMutateCh()) + } snapIter.Next() continue } // Do one string compare so we can check the various conditions // below without repeating the compare. - cmp := strings.Compare(snapIter.Path(), rootIter.Path()) + cmp := strings.Compare(string(getKey(snapIter.GetIterPath())), string(getKey(rootIter.GetIterPath()))) // If the snapshot is behind the node, then we must have deleted // this node during the transaction. @@ -247,10 +409,166 @@ func (t *Txn[T]) LongestPrefix(prefix []byte) ([]byte, T, bool) { // DeletePrefix is used to delete an entire subtree that matches the prefix // This will delete all nodes under that prefix func (t *Txn[T]) DeletePrefix(prefix []byte) bool { - newRoot, ok := t.tree.DeletePrefix(prefix) - if ok { + key := getTreeKey(prefix) + newRoot, numDeletions := t.deletePrefix(t.tree.root, key, 0) + if numDeletions != 0 { t.tree.root = newRoot + t.tree.size = t.tree.size - uint64(numDeletions) t.size = t.tree.size + return true } return false } + +func (t *Txn[T]) deletePrefix(node Node[T], key []byte, depth int) (Node[T], int) { + // Get terminated + if node == nil { + return nil, 0 + } + // Handle hitting a leaf node + if isLeaf[T](node) { + if bytes.HasPrefix(getKey(node.getKey()), getKey(key)) { + t.trackChannel(node.getMutateCh()) + return nil, 1 + } + return node, 0 + } + + // Bail if the prefix does not match + if node.getPartialLen() > 0 { + prefixLen := checkPrefix(node.getPartial(), int(node.getPartialLen()), key, depth) + if prefixLen < min(maxPrefixLen, len(getKey(key))) { + depth += prefixLen + } else { + return node, 0 + } + } + + t.trackChannel(node.getMutateCh()) + + numDel := 0 + + // Recurse on the children + var newChIndxMap = make(map[int]Node[T]) + for idx, ch := range node.getChildren() { + if ch != nil { + newCh, del := t.deletePrefix(ch, key, depth+1) + newChIndxMap[idx] = newCh + numDel += del + } + } + + for idx, ch := range newChIndxMap { + node.setChild(idx, ch) + } + + return node, numDel +} + +func (t *Txn[T]) makeLeaf(key []byte, value T) Node[T] { + // Allocate memory for the leaf node + l := t.allocNode(leafType) + + if l == nil { + return nil + } + + // Set the value and key length + l.setValue(value) + l.setKeyLen(uint32(len(key))) + l.setKey(key) + return l +} + +func (t *Txn[T]) writeNode(n Node[T]) Node[T] { + if t.writable == nil { + lru, err := simplelru.NewLRU[Node[T], any](defaultModifiedCache, nil) + if err != nil { + panic(err) + } + t.writable = lru + } + // If this node has already been modified, we can continue to use it + // during this transaction. We know that we don't need to track it for + // a node update since the node is writable, but if this is for a leaf + // update we track it, in case the initial write to this node didn't + // update the leaf. + if _, ok := t.writable.Get(n); ok { + if t.trackMutate { + t.trackChannel(n.getMutateCh()) + } + return n + } + // Mark this node as being mutated. + if t.trackMutate { + t.trackChannel(n.getMutateCh()) + } + + // Copy the existing node. If you have set forLeafUpdate it will be + // safe to replace this leaf with another after you get your node for + // writing. You MUST replace it, because the channel associated with + // this leaf will be closed when this transaction is committed. + nc := n.clone() + + // Mark this node as writable. + t.writable.Add(nc, nil) + return nc +} + +func (t *Txn[T]) allocNode(ntype nodeType) Node[T] { + var n Node[T] + switch ntype { + case leafType: + n = &NodeLeaf[T]{} + case node4: + n = &Node4[T]{} + case node16: + n = &Node16[T]{} + case node48: + n = &Node48[T]{} + case node256: + n = &Node256[T]{} + default: + panic("Unknown node type") + } + n.setMutateCh(make(chan struct{})) + n.setPartial(make([]byte, maxPrefixLen)) + n.setPartialLen(maxPrefixLen) + return n +} + +// trackChannel safely attempts to track the given mutation channel, setting the +// overflow flag if we can no longer track any more. This limits the amount of +// state that will accumulate during a transaction and we have a slower algorithm +// to switch to if we overflow. +func (t *Txn[T]) trackChannel(ch chan struct{}) { + // In overflow, make sure we don't store any more objects. + if t.trackOverflow { + return + } + + // If this would overflow the state we reject it and set the flag (since + // we aren't tracking everything that's required any longer). + if len(t.trackChannels) >= defaultModifiedCache { + // Mark that we are in the overflow state + t.trackOverflow = true + + // Clear the map so that the channels can be garbage collected. It is + // safe to do this since we have already overflowed and will be using + // the slow notify algorithm. + t.trackChannels = nil + return + } + + // Create the map on the fly when we need it. + if t.trackChannels == nil { + t.trackChannels = make(map[chan struct{}]struct{}) + } + + t.trackChannels[ch] = struct{}{} +} + +// findChild finds the child node pointer based on the given character in the ART tree node. +func (t *Txn[T]) findChild(n Node[T], c byte) (Node[T], int) { + return findChild(n, c) +}