Skip to content

Commit

Permalink
Fix Test and Add tests for Track Mutate + Add LRU Cache for writeNode (
Browse files Browse the repository at this point in the history
…#6)

* reverse iterator init

* some fixes

* fix reverse iterator seaklowerbound

* fix tests

* fix track channels

* longest prefix on txn

* some fixes

* revisit

* todo revisit

* major code refactor

* some minor fixes

* add lru

* fix bugs

* added walk func

* add more tests

* some prog

* some progress track mutate

* some progress

* fix tests
  • Loading branch information
absolutelightning authored May 21, 2024
1 parent 0a505a5 commit 231ca20
Show file tree
Hide file tree
Showing 15 changed files with 1,254 additions and 418 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.21

require (
github.com/hashicorp/go-uuid v1.0.3
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/stretchr/testify v1.9.0
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
Expand Down
64 changes: 11 additions & 53 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,44 +28,6 @@ func leafMatches(nodeKey []byte, key []byte) int {
return bytes.Compare(nodeKey, key)
}

func (t *RadixTree[T]) makeLeaf(key []byte, value T) Node[T] {
// Allocate memory for the leaf node
l := t.allocNode(leafType)

if l == nil {
return nil
}

// Set the value and key length
l.setValue(value)
l.setKeyLen(uint32(len(key)))
l.setKey(key)
return l
}

func (t *RadixTree[T]) allocNode(ntype nodeType) Node[T] {
var n Node[T]
switch ntype {
case leafType:
n = &NodeLeaf[T]{}
case node4:
n = &Node4[T]{}
case node16:
n = &Node16[T]{}
case node48:
n = &Node48[T]{}
case node256:
n = &Node256[T]{}
default:
panic("Unknown node type")
}
n.setMutateCh(make(chan struct{}))
n.setPartial(make([]byte, maxPrefixLen))
n.setPartialLen(maxPrefixLen)
t.trachChns[n.getMutateCh()] = struct{}{}
return n
}

// longestCommonPrefix finds the length of the longest common prefix between two leaf nodes.
func longestCommonPrefix[T any](l1, l2 Node[T], depth int) int {
maxCmp := len(l2.getKey()) - depth
Expand All @@ -82,7 +44,7 @@ func longestCommonPrefix[T any](l1, l2 Node[T], depth int) int {
}

// addChild adds a child node to the parent node.
func (t *RadixTree[T]) addChild(n Node[T], c byte, child Node[T]) Node[T] {
func (t *Txn[T]) addChild(n Node[T], c byte, child Node[T]) Node[T] {
switch n.getArtNodeType() {
case node4:
return t.addChild4(n, c, child)
Expand All @@ -98,7 +60,7 @@ func (t *RadixTree[T]) addChild(n Node[T], c byte, child Node[T]) Node[T] {
}

// addChild4 adds a child node to a node4.
func (t *RadixTree[T]) addChild4(n Node[T], c byte, child Node[T]) Node[T] {
func (t *Txn[T]) addChild4(n Node[T], c byte, child Node[T]) Node[T] {
if n.getNumChildren() < 4 {
idx := sort.Search(int(n.getNumChildren()), func(i int) bool {
return n.getKeyAtIdx(i) > c
Expand All @@ -124,7 +86,7 @@ func (t *RadixTree[T]) addChild4(n Node[T], c byte, child Node[T]) Node[T] {
}

// addChild16 adds a child node to a node16.
func (t *RadixTree[T]) addChild16(n Node[T], c byte, child Node[T]) Node[T] {
func (t *Txn[T]) addChild16(n Node[T], c byte, child Node[T]) Node[T] {
if n.getNumChildren() < 16 {
idx := sort.Search(int(n.getNumChildren()), func(i int) bool {
return n.getKeyAtIdx(i) > c
Expand Down Expand Up @@ -152,7 +114,7 @@ func (t *RadixTree[T]) addChild16(n Node[T], c byte, child Node[T]) Node[T] {
}

// addChild48 adds a child node to a node48.
func (t *RadixTree[T]) addChild48(n Node[T], c byte, child Node[T]) Node[T] {
func (t *Txn[T]) addChild48(n Node[T], c byte, child Node[T]) Node[T] {
if n.getNumChildren() < 48 {
pos := 0
for n.getChild(pos) != nil {
Expand All @@ -175,14 +137,14 @@ func (t *RadixTree[T]) addChild48(n Node[T], c byte, child Node[T]) Node[T] {
}

// addChild256 adds a child node to a node256.
func (t *RadixTree[T]) addChild256(n Node[T], c byte, child Node[T]) Node[T] {
func (t *Txn[T]) addChild256(n Node[T], c byte, child Node[T]) Node[T] {
n.setNumChildren(n.getNumChildren() + 1)
n.setChild(int(c), child)
return n
}

// copyHeader copies header information from src to dest node.
func (t *RadixTree[T]) copyHeader(dest, src Node[T]) {
func (t *Txn[T]) copyHeader(dest, src Node[T]) {
dest.setNumChildren(src.getNumChildren())
dest.setPartialLen(src.getPartialLen())
length := min(maxPrefixLen, int(src.getPartialLen()))
Expand Down Expand Up @@ -308,10 +270,6 @@ func isLeaf[T any](node Node[T]) bool {
return node.isLeaf()
}

// findChild finds the child node pointer based on the given character in the ART tree node.
func (t *RadixTree[T]) findChild(n Node[T], c byte) (Node[T], int) {
return findChild(n, c)
}
func findChild[T any](n Node[T], c byte) (Node[T], int) {
switch n.getArtNodeType() {
case node4:
Expand Down Expand Up @@ -360,7 +318,7 @@ func getKey(key []byte) []byte {
return key[1 : len(key)-1]
}

func (t *RadixTree[T]) removeChild(n Node[T], c byte) Node[T] {
func (t *Txn[T]) removeChild(n Node[T], c byte) Node[T] {
switch n.getArtNodeType() {
case node4:
return t.removeChild4(n.(*Node4[T]), c)
Expand All @@ -375,7 +333,7 @@ func (t *RadixTree[T]) removeChild(n Node[T], c byte) Node[T] {
}
}

func (t *RadixTree[T]) removeChild4(n *Node4[T], c byte) Node[T] {
func (t *Txn[T]) removeChild4(n *Node4[T], c byte) Node[T] {
pos := sort.Search(int(n.numChildren), func(i int) bool {
return n.keys[i] >= c
})
Expand Down Expand Up @@ -412,7 +370,7 @@ func (t *RadixTree[T]) removeChild4(n *Node4[T], c byte) Node[T] {
return n
}

func (t *RadixTree[T]) removeChild16(n *Node16[T], c byte) Node[T] {
func (t *Txn[T]) removeChild16(n *Node16[T], c byte) Node[T] {
pos := sort.Search(int(n.numChildren), func(i int) bool {
return n.keys[i] >= c
})
Expand All @@ -432,7 +390,7 @@ func (t *RadixTree[T]) removeChild16(n *Node16[T], c byte) Node[T] {
return n
}

func (t *RadixTree[T]) removeChild48(n *Node48[T], c uint8) Node[T] {
func (t *Txn[T]) removeChild48(n *Node48[T], c uint8) Node[T] {
pos := n.keys[c]
n.keys[c] = 0
n.children[pos-1] = nil
Expand All @@ -455,7 +413,7 @@ func (t *RadixTree[T]) removeChild48(n *Node48[T], c uint8) Node[T] {
return n
}

func (t *RadixTree[T]) removeChild256(n *Node256[T], c uint8) Node[T] {
func (t *Txn[T]) removeChild256(n *Node256[T], c uint8) Node[T] {
n.children[c] = nil
n.numChildren--

Expand Down
17 changes: 13 additions & 4 deletions iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ type Iterator[T any] struct {
lowerBound bool
reverseLowerBound bool
seenMismatch bool
iterPath []byte
}

func (i *Iterator[T]) GetIterPath() []byte {
return i.iterPath
}

// Front returns the current node that has been iterated to.
Expand All @@ -34,6 +39,7 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) {
var zero T

if len(i.stack) == 0 {
i.pos = nil
return nil, zero, false
}

Expand All @@ -54,15 +60,13 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) {
leafCh := currentNode.(*NodeLeaf[T])
if i.lowerBound {
if bytes.Compare(getKey(leafCh.key), getKey(i.path)) >= 0 {
i.pos = leafCh
return getKey(leafCh.key), leafCh.value, true
}
continue
}
if len(i.Path()) >= 2 && !leafCh.matchPrefix([]byte(i.Path())) {
continue
}
i.pos = leafCh
return getKey(leafCh.key), leafCh.value, true
case node4:
n4 := currentNode.(*Node4[T])
Expand All @@ -77,6 +81,7 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) {
newStack[0] = child
i.stack = newStack
}
i.iterPath = append(i.iterPath, n4.getPartial()[:n4.getPartialLen()]...)
case node16:
n16 := currentNode.(*Node16[T])
for itr := 15; itr >= 0; itr-- {
Expand All @@ -90,6 +95,7 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) {
newStack[0] = child
i.stack = newStack
}
i.iterPath = append(i.iterPath, n16.getPartial()[:n16.getPartialLen()]...)
case node48:
n48 := currentNode.(*Node48[T])
for itr := 0; itr < 256; itr++ {
Expand All @@ -107,6 +113,7 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) {
newStack[0] = child
i.stack = newStack
}
i.iterPath = append(i.iterPath, n48.getPartial()[:n48.getPartialLen()]...)
case node256:
n256 := currentNode.(*Node256[T])
for itr := 255; itr >= 0; itr-- {
Expand All @@ -120,14 +127,15 @@ func (i *Iterator[T]) Next() ([]byte, T, bool) {
newStack[0] = child
i.stack = newStack
}
i.iterPath = append(i.iterPath, n256.getPartial()[:n256.getPartialLen()]...)
}
}
i.pos = nil
return nil, zero, false
}

func (i *Iterator[T]) SeekPrefixWatch(prefixKey []byte) (watch <-chan struct{}) {
// Start from the node node
// Start from the node
node := i.node
watch = node.getMutateCh()

Expand All @@ -142,6 +150,7 @@ func (i *Iterator[T]) SeekPrefixWatch(prefixKey []byte) (watch <-chan struct{})
// Check if the node matches the prefix
i.stack = []Node[T]{node}
i.node = node
watch = node.getMutateCh()

if node.isLeaf() {
return
Expand Down Expand Up @@ -178,7 +187,7 @@ func (i *Iterator[T]) SeekPrefixWatch(prefixKey []byte) (watch <-chan struct{})
watch = node.getMutateCh()
depth++
}
return watch
return
}

func (i *Iterator[T]) SeekPrefix(prefixKey []byte) {
Expand Down
6 changes: 3 additions & 3 deletions iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ func TestIterateLowerBoundFuzz(t *testing.T) {
// the same list as filtering all sorted keys that are lower.

radixAddAndScan := func(newKey, searchKey readableString) []string {
r.Insert([]byte(newKey), "")
r, _, _ = r.Insert([]byte(newKey), "")

t.Log("NewKey: ", newKey, "SearchKey: ", searchKey)

// Now iterate the tree from searchKey to the end
it := r.root.Iterator()
it := r.Root().Iterator()
var result []string
it.SeekLowerBound([]byte(searchKey))
for {
Expand Down Expand Up @@ -320,7 +320,7 @@ func TestIterateLowerBound(t *testing.T) {
// Insert keys
for _, k := range test.keys {
var ok bool
r.Insert([]byte(k), nil)
r, _, _ = r.Insert([]byte(k), nil)
if ok {
t.Fatalf("duplicate key %s in keys", k)
}
Expand Down
1 change: 1 addition & 0 deletions node_16.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func (n *Node16[T]) clone() Node[T] {
numChildren: n.getNumChildren(),
partial: n.getPartial(),
}
newNode.mutateCh = make(chan struct{})
copy(newNode.keys[:], n.keys[:])
copy(newNode.children[:], n.children[:])
nodeT := Node[T](newNode)
Expand Down
1 change: 1 addition & 0 deletions node_256.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (n *Node256[T]) clone() Node[T] {
numChildren: n.getNumChildren(),
partial: n.getPartial(),
}
newNode.mutateCh = make(chan struct{})
copy(newNode.children[:], n.children[:])
return newNode
}
Expand Down
1 change: 1 addition & 0 deletions node_4.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func (n *Node4[T]) clone() Node[T] {
numChildren: n.getNumChildren(),
partial: n.getPartial(),
}
newNode.mutateCh = make(chan struct{})
copy(newNode.keys[:], n.keys[:])
copy(newNode.children[:], n.children[:])
return newNode
Expand Down
1 change: 1 addition & 0 deletions node_48.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func (n *Node48[T]) clone() Node[T] {
numChildren: n.getNumChildren(),
partial: n.getPartial(),
}
newNode.mutateCh = make(chan struct{})
copy(newNode.keys[:], n.keys[:])
copy(newNode.children[:], n.children[:])
return newNode
Expand Down
1 change: 1 addition & 0 deletions node_leaf.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func (n *NodeLeaf[T]) clone() Node[T] {
key: make([]byte, len(n.getKey())),
value: n.getValue(),
}
newNode.mutateCh = make(chan struct{})
copy(newNode.key[:], n.key[:])
nodeT := Node[T](newNode)
return nodeT
Expand Down
2 changes: 1 addition & 1 deletion path_iter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestPathIterator(t *testing.T) {
"zipzap",
}
for _, k := range keys {
_ = r.Insert([]byte(k), nil)
r, _, _ = r.Insert([]byte(k), nil)
}
if int(r.size) != len(keys) {
t.Fatalf("bad len: %v %v", r.size, len(keys))
Expand Down
16 changes: 8 additions & 8 deletions reverse_iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ func TestReverseIterator_SeekReverseLowerBoundFuzz(t *testing.T) {
// produces the same list as filtering all sorted keys that are bigger.

radixAddAndScan := func(newKey, searchKey readableString) []string {
r.Insert([]byte(newKey), nil)
r, _, _ = r.Insert([]byte(newKey), nil)

// Now iterate the tree from searchKey to the beginning
it := r.root.ReverseIterator()
it := r.Root().ReverseIterator()
var result []string
it.SeekReverseLowerBound([]byte(searchKey))
for {
Expand Down Expand Up @@ -305,7 +305,7 @@ func TestReverseIterator_SeekLowerBound(t *testing.T) {
// Insert keys
for _, k := range test.keys {
var ok bool
r.Insert([]byte(k), nil)
r, _, _ = r.Insert([]byte(k), nil)
if ok {
t.Fatalf("duplicate key %s in keys", k)
}
Expand Down Expand Up @@ -361,7 +361,7 @@ func TestReverseIterator_SeekPrefix(t *testing.T) {

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
it := r.root.ReverseIterator()
it := r.Root().ReverseIterator()
it.SeekPrefix([]byte(c.prefix))

if c.expectResult && it.i.node == nil {
Expand All @@ -382,10 +382,10 @@ func TestReverseIterator_SeekPrefixWatch(t *testing.T) {

// Create tree
r := NewRadixTree[any]()
r.Insert(key, nil)
r, _, _ = r.Insert(key, nil)

// Find mutate channel
it := r.root.ReverseIterator()
it := r.Root().ReverseIterator()
ch := it.SeekPrefixWatch(key)

// Change prefix
Expand All @@ -406,10 +406,10 @@ func TestReverseIterator_Previous(t *testing.T) {
r := NewRadixTree[any]()
keys := []string{"001", "002", "005", "010", "100"}
for _, k := range keys {
r.Insert([]byte(k), nil)
r, _, _ = r.Insert([]byte(k), nil)
}

it := r.root.ReverseIterator()
it := r.Root().ReverseIterator()

for i := len(keys) - 1; i >= 0; i-- {
got, _, _ := it.Previous()
Expand Down
Loading

0 comments on commit 231ca20

Please sign in to comment.