diff --git a/iterator.go b/iterator.go index aa32f14..9f4f2ba 100644 --- a/iterator.go +++ b/iterator.go @@ -151,8 +151,6 @@ func (i *Iterator[T]) SeekPrefixWatch(prefixKey []byte) (watch <-chan struct{}) i.stack = []Node[T]{node} i.node = node - parent := node - for { // Check if the node matches the prefix @@ -196,8 +194,7 @@ func (i *Iterator[T]) SeekPrefixWatch(prefixKey []byte) (watch <-chan struct{}) i.node = node // Move to the next level in the tree - parent = node - watch = parent.getMutateCh() + watch = node.getMutateCh() node = child depth++ } diff --git a/tree_test.go b/tree_test.go index 9453718..821cafa 100644 --- a/tree_test.go +++ b/tree_test.go @@ -89,7 +89,8 @@ func TestARTree_InsertAndSearchWords(t *testing.T) { // optionally, resize scanner's capacity for lines over 64K, see next example lineNumber := 1 for scanner.Scan() { - art, _, _ = art.Insert(scanner.Bytes(), lineNumber) + line := scanner.Text() + art, _, _ = art.Insert([]byte(line), lineNumber) lineNumber += 1 lines = append(lines, scanner.Text()) } diff --git a/txn.go b/txn.go index d532514..5c28c53 100644 --- a/txn.go +++ b/txn.go @@ -117,11 +117,14 @@ func (t *Txn[T]) recursiveInsert(node Node[T], key []byte, value T, depth int, o node.processLazyRef() + oldRef := node + if node.isLeaf() { // This means node is nil if node.getKeyLen() == 0 { nl := t.makeLeaf(key, value) nl.processLazyRef() + node.incrementLazyRefCount(-1) t.trackChannel(node) return nl, zero } @@ -146,7 +149,14 @@ func (t *Txn[T]) recursiveInsert(node Node[T], key []byte, value T, depth int, o doClone := node.getRefCount() > 1 if doClone { + oldRef.incrementLazyRefCount(-1) node = t.writeNode(node) + } else { + defer func() { + oldRef.incrementLazyRefCount(-1) + oldRef.processLazyRef() + }() + node.incrementLazyRefCount(1) } // Determine longest prefix @@ -169,13 +179,19 @@ func (t *Txn[T]) recursiveInsert(node Node[T], key []byte, value T, depth int, o return newNode, zero } - oldRef := node oldRef.processLazyRef() // Check if given node has a prefix if node.getPartialLen() > 0 { doClone := node.getRefCount() > 1 if doClone { + oldRef.incrementLazyRefCount(-1) node = t.writeNode(node) + } else { + defer func() { + oldRef.incrementLazyRefCount(-1) + oldRef.processLazyRef() + }() + node.incrementLazyRefCount(1) } // Determine if the prefixes differ, since we need to split prefixDiff := prefixMismatch[T](node, key, len(key), depth) @@ -243,7 +259,14 @@ func (t *Txn[T]) recursiveInsert(node Node[T], key []byte, value T, depth int, o doClone := node.getRefCount() > 1 if doClone { + oldRef.incrementLazyRefCount(-1) node = t.writeNode(node) + } else { + defer func() { + oldRef.incrementLazyRefCount(-1) + oldRef.processLazyRef() + }() + node.incrementLazyRefCount(1) } if depth < len(key) { @@ -306,8 +329,6 @@ func (t *Txn[T]) recursiveDelete(node Node[T], key []byte, depth int) (Node[T], return nil, nil } - doClone := node.getRefCount() > 1 - node.processLazyRef() // Handle hitting a leaf node @@ -319,7 +340,7 @@ func (t *Txn[T]) recursiveDelete(node Node[T], key []byte, depth int) (Node[T], return node, nil } - node.incrementLazyRefCount(1) + oldRef := node // Bail if the prefix does not match if node.getPartialLen() > 0 { @@ -333,19 +354,23 @@ func (t *Txn[T]) recursiveDelete(node Node[T], key []byte, depth int) (Node[T], // Find child node child, idx := t.findChild(node, key[depth]) if child == nil { - node.incrementLazyRefCount(-1) return nil, nil } - oldRef := node - // Recurse newChild, val := t.recursiveDelete(child, key, depth+1) if newChild != child { + doClone := node.getRefCount() > 1 + if doClone { + oldRef.incrementLazyRefCount(-1) node = t.writeNode(node) } else { + defer func() { + oldRef.incrementLazyRefCount(-1) + }() + node.incrementLazyRefCount(1) t.trackChannel(oldRef) } node.setChild(idx, newChild) @@ -353,20 +378,10 @@ func (t *Txn[T]) recursiveDelete(node Node[T], key []byte, depth int) (Node[T], if newChild == nil { t.trackChannel(child) - - if doClone { - node = t.writeNode(node) - } else { - t.trackChannel(oldRef) - } - if doClone { - oldRef.incrementLazyRefCount(-1) - } node = t.removeChild(node, key[depth]) } oldRef.processLazyRef() - oldRef.incrementLazyRefCount(-1) return node, val }