From d65c36c5730fbe05be1bfde9279242cb3fcbf44f Mon Sep 17 00:00:00 2001 From: Andrei Borzenkov Date: Fri, 11 Oct 2024 12:58:08 +0400 Subject: [PATCH] Inline color field into map constructors --- src/PersistentOrderedMap.mo | 194 +++++++++++++++++------------- test/PersistentOrderedMap.test.mo | 41 ++++--- 2 files changed, 130 insertions(+), 105 deletions(-) diff --git a/src/PersistentOrderedMap.mo b/src/PersistentOrderedMap.mo index 1fa00545..b3e8d434 100644 --- a/src/PersistentOrderedMap.mo +++ b/src/PersistentOrderedMap.mo @@ -38,14 +38,12 @@ import O "Order"; module { - /// Node color: Either red (`#R`) or black (`#B`). - public type Color = { #R; #B }; - /// Red-black tree of nodes with key-value entries, ordered by the keys. /// The keys have the generic type `K` and the values the generic type `V`. /// Leaves are considered implicitly black. public type Map = { - #node : (Color, Map, (K, V), Map); + #red : (Map, (K, V), Map); + #black : (Map, (K, V), Map); #leaf }; @@ -351,7 +349,11 @@ module { trees := ts; ?xy }; - case (?(#tr(#node(_, l, xy, r)), ts)) { + case (?(#tr(#red(l, xy, r)), ts)) { + trees := mapTraverser(l, xy, r, ts); + next() + }; + case (?(#tr(#black(l, xy, r)), ts)) { trees := mapTraverser(l, xy, r, ts); next() } @@ -467,8 +469,11 @@ module { func mapRec(m : Map) : Map { switch m { case (#leaf) { #leaf }; - case (#node(c, l, xy, r)) { - #node(c, mapRec l, (xy.0, f xy), mapRec r) // TODO: try destination-passing style to avoid non tail-call recursion + case (#red(l, xy, r)) { + #red(mapRec l, (xy.0, f xy), mapRec r) + }; + case (#black(l, xy, r)) { + #black(mapRec l, (xy.0, f xy), mapRec r) }; } }; @@ -497,7 +502,10 @@ module { public func size(t : Map) : Nat { switch t { case (#leaf) { 0 }; - case (#node(_, l, _, r)) { + case (#red(l, _, r)) { + size(l) + size(r) + 1 + }; + case (#black(l, _, r)) { size(l) + size(r) + 1 } } @@ -538,7 +546,12 @@ module { { switch (rbMap) { case (#leaf) { base }; - case (#node(_, l, (k, v), r)) { + case (#red(l, (k, v), r)) { + let left = foldLeft(l, base, combine); + let middle = combine(k, v, left); + foldLeft(r, middle, combine) + }; + case (#black(l, (k, v), r)) { let left = foldLeft(l, base, combine); let middle = combine(k, v, left); foldLeft(r, middle, combine) @@ -581,7 +594,12 @@ module { { switch (rbMap) { case (#leaf) { base }; - case (#node(_, l, (k, v), r)) { + case (#red(l, (k, v), r)) { + let right = foldRight(r, base, combine); + let middle = combine(k, v, right); + foldRight(l, middle, combine) + }; + case (#black(l, (k, v), r)) { let right = foldRight(r, base, combine); let middle = combine(k, v, right); foldRight(l, middle, combine) @@ -616,7 +634,14 @@ module { public func get(t : Map, compare : (K, K) -> O.Order, x : K) : ?V { switch t { case (#leaf) { null }; - case (#node(_c, l, xy, r)) { + case (#red(l, xy, r)) { + switch (compare(x, xy.0)) { + case (#less) { get(l, compare, x) }; + case (#equal) { ?xy.1 }; + case (#greater) { get(r, compare, x) } + } + }; + case (#black(l, xy, r)) { switch (compare(x, xy.0)) { case (#less) { get(l, compare, x) }; case (#equal) { ?xy.1 }; @@ -628,8 +653,8 @@ module { func redden(t : Map) : Map { switch t { - case (#node (#B, l, xy, r)) { - (#node (#R, l, xy, r)) + case (#black (l, xy, r)) { + (#red (l, xy, r)) }; case _ { Debug.trap "RBTree.red" @@ -639,44 +664,40 @@ module { func lbalance(left : Map, xy : (K,V), right : Map) : Map { switch (left, right) { - case (#node(#R, #node(#R, l1, xy1, r1), xy2, r2), r) { - #node( - #R, - #node(#B, l1, xy1, r1), + case (#red(#red(l1, xy1, r1), xy2, r2), r) { + #red( + #black(l1, xy1, r1), xy2, - #node(#B, r2, xy, r)) + #black(r2, xy, r)) }; - case (#node(#R, l1, xy1, #node(#R, l2, xy2, r2)), r) { - #node( - #R, - #node(#B, l1, xy1, l2), + case (#red(l1, xy1, #red(l2, xy2, r2)), r) { + #red( + #black(l1, xy1, l2), xy2, - #node(#B, r2, xy, r)) + #black(r2, xy, r)) }; case _ { - #node(#B, left, xy, right) + #black(left, xy, right) } } }; func rbalance(left : Map, xy : (K,V), right : Map) : Map { switch (left, right) { - case (l, #node(#R, l1, xy1, #node(#R, l2, xy2, r2))) { - #node( - #R, - #node(#B, l, xy, l1), + case (l, #red(l1, xy1, #red(l2, xy2, r2))) { + #red( + #black(l, xy, l1), xy1, - #node(#B, l2, xy2, r2)) + #black(l2, xy2, r2)) }; - case (l, #node(#R, #node(#R, l1, xy1, r1), xy2, r2)) { - #node( - #R, - #node(#B, l, xy, l1), + case (l, #red(#red(l1, xy1, r1), xy2, r2)) { + #red( + #black(l, xy, l1), xy1, - #node(#B, r1, xy2, r2)) + #black(r1, xy2, r2)) }; case _ { - #node(#B, left, xy, right) + #black(left, xy, right) }; } }; @@ -694,9 +715,9 @@ module { func ins(tree : Map) : Map { switch tree { case (#leaf) { - #node(#R, #leaf, (key,val), #leaf) + #red(#leaf, (key,val), #leaf) }; - case (#node(#B, left, xy, right)) { + case (#black(left, xy, right)) { switch (compare (key, xy.0)) { case (#less) { lbalance(ins left, xy, right) @@ -706,29 +727,29 @@ module { }; case (#equal) { let newVal = onClash({ new = val; old = xy.1 }); - #node(#B, left, (key,newVal), right) + #black(left, (key,newVal), right) } } }; - case (#node(#R, left, xy, right)) { + case (#red(left, xy, right)) { switch (compare (key, xy.0)) { case (#less) { - #node(#R, ins left, xy, right) + #red(ins left, xy, right) }; case (#greater) { - #node(#R, left, xy, ins right) + #red(left, xy, ins right) }; case (#equal) { let newVal = onClash { new = val; old = xy.1 }; - #node(#R, left, (key,newVal), right) + #red(left, (key,newVal), right) } } } }; }; switch (ins m) { - case (#node(#R, left, xy, right)) { - #node(#B, left, xy, right); + case (#red(left, xy, right)) { + #black(left, xy, right); }; case other { other }; }; @@ -761,19 +782,18 @@ module { func balLeft(left : Map, xy : (K,V), right : Map) : Map { switch (left, right) { - case (#node(#R, l1, xy1, r1), r) { - #node( - #R, - #node(#B, l1, xy1, r1), + case (#red(l1, xy1, r1), r) { + #red( + #black(l1, xy1, r1), xy, r) }; - case (_, #node(#B, l2, xy2, r2)) { - rbalance(left, xy, #node(#R, l2, xy2, r2)) + case (_, #black(l2, xy2, r2)) { + rbalance(left, xy, #red(l2, xy2, r2)) }; - case (_, #node(#R, #node(#B, l2, xy2, r2), xy3, r3)) { - #node(#R, - #node(#B, left, xy, l2), + case (_, #red(#black(l2, xy2, r2), xy3, r3)) { + #red( + #black(left, xy, l2), xy2, rbalance(r2, xy3, redden r3)) }; @@ -783,20 +803,20 @@ module { func balRight(left : Map, xy : (K,V), right : Map) : Map { switch (left, right) { - case (l, #node(#R, l1, xy1, r1)) { - #node(#R, + case (l, #red(l1, xy1, r1)) { + #red( l, xy, - #node(#B, l1, xy1, r1)) + #black(l1, xy1, r1)) }; - case (#node(#B, l1, xy1, r1), r) { - lbalance(#node(#R, l1, xy1, r1), xy, r); + case (#black(l1, xy1, r1), r) { + lbalance(#red(l1, xy1, r1), xy, r); }; - case (#node(#R, l1, xy1, #node(#B, l2, xy2, r2)), r3) { - #node(#R, + case (#red(l1, xy1, #black(l2, xy2, r2)), r3) { + #red( lbalance(redden l1, xy1, l2), xy2, - #node(#B, r2, xy, r3)) + #black(r2, xy, r3)) }; case _ { Debug.trap "balRight" }; } @@ -806,40 +826,39 @@ module { switch (left, right) { case (#leaf, _) { right }; case (_, #leaf) { left }; - case (#node (#R, l1, xy1, r1), - #node (#R, l2, xy2, r2)) { + case (#red (l1, xy1, r1), + #red (l2, xy2, r2)) { switch (append (r1, l2)) { - case (#node (#R, l3, xy3, r3)) { - #node( - #R, - #node(#R, l1, xy1, l3), + case (#red (l3, xy3, r3)) { + #red( + #red(l1, xy1, l3), xy3, - #node(#R, r3, xy2, r2)) + #red(r3, xy2, r2)) }; case r1l2 { - #node(#R, l1, xy1, #node(#R, r1l2, xy2, r2)) + #red(l1, xy1, #red(r1l2, xy2, r2)) } } }; - case (t1, #node(#R, l2, xy2, r2)) { - #node(#R, append(t1, l2), xy2, r2) + case (t1, #red(l2, xy2, r2)) { + #red(append(t1, l2), xy2, r2) }; - case (#node(#R, l1, xy1, r1), t2) { - #node(#R, l1, xy1, append(r1, t2)) + case (#red(l1, xy1, r1), t2) { + #red(l1, xy1, append(r1, t2)) }; - case (#node(#B, l1, xy1, r1), #node (#B, l2, xy2, r2)) { + case (#black(l1, xy1, r1), #black (l2, xy2, r2)) { switch (append (r1, l2)) { - case (#node (#R, l3, xy3, r3)) { - #node(#R, - #node(#B, l1, xy1, l3), + case (#red (l3, xy3, r3)) { + #red( + #black(l1, xy1, l3), xy3, - #node(#B, r3, xy2, r2)) + #black(r3, xy2, r2)) }; case r1l2 { balLeft ( l1, xy1, - #node(#B, r1l2, xy2, r2) + #black(r1l2, xy2, r2) ) } } @@ -857,22 +876,22 @@ module { case (#less) { let newLeft = del left; switch left { - case (#node(#B, _, _, _)) { + case (#black(_, _, _)) { balLeft(newLeft, xy, right) }; case _ { - #node(#R, newLeft, xy, right) + #red(newLeft, xy, right) } } }; case (#greater) { let newRight = del right; switch right { - case (#node(#B, _, _, _)) { + case (#black(_, _, _)) { balRight(left, xy, newRight) }; case _ { - #node(#R, left, xy, newRight) + #red(left, xy, newRight) } } }; @@ -887,14 +906,17 @@ module { case (#leaf) { tree }; - case (#node(_, left, xy, right)) { + case (#red(left, xy, right)) { + delNode(left, xy, right) + }; + case (#black(left, xy, right)) { delNode(left, xy, right) } }; }; switch (del(tree)) { - case (#node(#R, left, xy, right)) { - (#node(#B, left, xy, right), y0); + case (#red(left, xy, right)) { + (#black(left, xy, right), y0); }; case other { (other, y0) }; }; diff --git a/test/PersistentOrderedMap.test.mo b/test/PersistentOrderedMap.test.mo index 58e7b41d..0e6794c6 100644 --- a/test/PersistentOrderedMap.test.mo +++ b/test/PersistentOrderedMap.test.mo @@ -31,24 +31,24 @@ func checkMap(rbMap : Map.Map) { }; func blackDepth(node : Map.Map) : Nat { + func checkNode(left : Map.Map, key : Nat, right : Map.Map) : Nat { + checkKey(left, func(x) { x < key }); + checkKey(right, func(x) { x > key }); + let leftBlacks = blackDepth(left); + let rightBlacks = blackDepth(right); + assert (leftBlacks == rightBlacks); + leftBlacks + }; switch node { case (#leaf) 0; - case (#node(color, left, (key, _), right)) { - checkKey(left, func(x) { x < key }); - checkKey(right, func(x) { x > key }); - let leftBlacks = blackDepth(left); - let rightBlacks = blackDepth(right); - assert (leftBlacks == rightBlacks); - switch color { - case (#R) { - assert (not isRed(left)); - assert (not isRed(right)); - leftBlacks - }; - case (#B) { - leftBlacks + 1 - } - } + case (#red(left, (key, _), right)) { + let leftBlacks = checkNode(left, key, right); + assert (not isRed(left)); + assert (not isRed(right)); + leftBlacks + }; + case (#black(left, (key, _), right)) { + checkNode(left, key, right) + 1 } } }; @@ -56,15 +56,18 @@ func blackDepth(node : Map.Map) : Nat { func isRed(node : Map.Map) : Bool { switch node { - case (#leaf) false; - case (#node(color, _, _, _)) color == #R + case (#red(_, _, _)) true; + case _ false } }; func checkKey(node : Map.Map, isValid : Nat -> Bool) { switch node { case (#leaf) {}; - case (#node(_, _, (key, _), _)) { + case (#red( _, (key, _), _)) { + assert (isValid(key)) + }; + case (#black( _, (key, _), _)) { assert (isValid(key)) } }