diff --git a/src/PersistentOrderedSet.mo b/src/PersistentOrderedSet.mo index 0ed84fa0..d28be44e 100644 --- a/src/PersistentOrderedSet.mo +++ b/src/PersistentOrderedSet.mo @@ -1,22 +1,41 @@ -/// Stable ordered set implemented as a red-black tree. Currently built on top of PersistentOrderedMap by storing () values. +/// Stable ordered set implemented as a red-black tree. +/// +/// A red-black tree is a balanced binary search tree ordered by the elements. +/// +/// The tree data structure internally colors each of its nodes either red or black, +/// and uses this information to balance the tree during the modifying operations. /// /// Performance: /// * Runtime: `O(log(n))` worst case cost per insertion, removal, and retrieval operation. /// * Space: `O(n)` for storing the entire tree. /// `n` denotes the number of elements (i.e. nodes) stored in the tree. /// +/// Credits: +/// +/// The core of this implementation is derived from: +/// +/// * Ken Friis Larsen's [RedBlackMap.sml](https://github.com/kfl/mosml/blob/master/src/mosmllib/Redblackmap.sml), which itself is based on: +/// * Stefan Kahrs, "Red-black trees with types", Journal of Functional Programming, 11(4): 425-432 (2001), [version 1 in web appendix](http://www.cs.ukc.ac.uk/people/staff/smk/redblack/rb.html). /// The set operations implementation is derived from: /// Tobias Nipkow's "Functional Data Structures and Algorithms", 10: 117-125 (2024). - -import Map "PersistentOrderedMap"; +import Debug "Debug"; import I "Iter"; +import List "List"; import Nat "Nat"; -import O "Order"; import Option "Option"; +import O "Order"; module { - public type Set = Map.Map; + /// Node color: Either red (`#R`) or black (`#B`). + public type Color = { #R; #B }; + + /// Red-black tree of nodes with ordered set elements. + /// Leaves are considered implicitly black. + public type Set = { + #node : (Color, Set, T, Set); + #leaf + }; /// Opertaions on `Set`, that require a comparator. /// @@ -26,7 +45,6 @@ module { /// `SetOps` contains methods that require `compare` internally: /// operations that may reshape a `Set` or should find something. public class SetOps(compare : (T, T) -> O.Order) { - let mapOps = Map.MapOps(compare); /// Returns a new Set, containing all entries given by the iterator `i`. /// If there are multiple identical entries only one is taken. @@ -86,7 +104,7 @@ module { /// assuming that the `compare` function implements an `O(1)` comparison. /// /// Note: Creates `O(log(n))` temporary objects that will be collected as garbage. - public func put(rbSet : Set, value : T) : Set = mapOps.put<()>(rbSet, value, ()); + public func put(rbSet : Set, value : T) : Set = Internal.put(rbSet, compare, value); /// Deletes the value `value` from the `rbSet`. Has no effect if `value` is not /// present in the set. Returns modified set. @@ -112,7 +130,7 @@ module { /// assuming that the `compare` function implements an `O(1)` comparison. /// /// Note: Creates `O(log(n))` temporary objects that will be collected as garbage. - public func delete(rbSet : Set, value : T) : Set = mapOps.delete<()>(rbSet, value); + public func delete(rbSet : Set, value : T) : Set = Internal.delete(rbSet, compare, value); /// Test if a set contains a given element. /// @@ -137,7 +155,7 @@ module { /// assuming that the `compare` function implements an `O(1)` comparison. /// /// Note: Creates `O(log(n))` temporary objects that will be collected as garbage. - public func contains(rbSet : Set, value : T) : Bool = Option.isSome(mapOps.get(rbSet, value)); + public func contains(rbSet : Set, value : T) : Bool = Internal.contains(rbSet, compare, value); /// [Set union](https://en.wikipedia.org/wiki/Union_(set_theory)) operation. /// @@ -163,9 +181,9 @@ module { switch (rbSet1, rbSet2) { case (#leaf, rbSet) { rbSet }; case (rbSet, #leaf) { rbSet }; - case (#node (_,l1, (k, v), r1), _) { - let (l2, _, r2) = Map.Internal.split(k, rbSet2, compare); - Map.Internal.join(union(l1, l2), (k, v), union(r1, r2)) + case (#node (_, l1, x, r1), _) { + let (l2, _, r2) = Internal.split(x, rbSet2, compare); + Internal.join(union(l1, l2), x, union(r1, r2)) }; }; }; @@ -194,12 +212,12 @@ module { switch (rbSet1, rbSet2) { case (#leaf, _) { #leaf }; case (_, #leaf) { #leaf }; - case (#node (_, l1, (k, v), r1), _) { - let (l2, b2, r2) = Map.Internal.split(k, rbSet2, compare); + case (#node (_, l1, x, r1), _) { + let (l2, b2, r2) = Internal.split(x, rbSet2, compare); let l = intersect(l1, l2); let r = intersect(r1, r2); - if b2 { Map.Internal.join (l, (k, v), r) } - else { Map.Internal.join2(l, r) }; + if b2 { Internal.join (l, x, r) } + else { Internal.join2(l, r) }; }; }; }; @@ -228,9 +246,9 @@ module { switch (rbSet1, rbSet2) { case (#leaf, _) { #leaf }; case (rbSet, #leaf) { rbSet }; - case (_, (#node(_, l2, (k, _), r2))) { - let (l1, _, r1) = Map.Internal.split(k, rbSet1, compare); - Map.Internal.join2(diff(l1, l2), diff(r1, r2)); + case (_, (#node(_, l2, x, r2))) { + let (l1, _, r1) = Internal.split(x, rbSet1, compare); + Internal.join2(diff(l1, l2), diff(r1, r2)); } } }; @@ -242,7 +260,7 @@ module { /// /// Example: /// ```motoko - /// import Map "mo:base/PersistentOrderedMap"; + /// import Set "mo:base/PersistentOrderedSet"; /// import Nat "mo:base/Nat" /// import Iter "mo:base/Iter" /// @@ -270,7 +288,7 @@ module { /// /// Example: /// ```motoko - /// import Map "mo:base/PersistentOrderedMap"; + /// import Set "mo:base/PersistentOrderedSet"; /// import Nat "mo:base/Nat" /// import Iter "mo:base/Iter"; /// @@ -376,6 +394,62 @@ module { }; }; + type IterRep = List.List<{ #tr : Set; #x : T }>; + + public type Direction = { #fwd; #bwd }; + + /// Get an iterator for the elements of the `rbSet`, in ascending (`#fwd`) or descending (`#bwd`) order as specified by `direction`. + /// The iterator takes a snapshot view of the set and is not affected by concurrent modifications. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/PersistentOrderedSet"; + /// import Nat "mo:base/Nat" + /// import Iter "mo:base/Iter" + /// + /// let setOps = Set.SetOps(Nat.compare); + /// let rbSet = setOps.fromIter(Iter.fromArray([(0, 2, 1)])); + /// + /// Debug.print(debug_show(Iter.toArray(Set.iter(rbSet, #fwd)))); + /// Debug.print(debug_show(Iter.toArray(Map.iter(rbSet, #bwd)))); + /// + /// // [0, 1, 2] + /// // [2, 1, 0] + /// ``` + /// + /// Cost of iteration over all elements: + /// Runtime: `O(n)`. + /// Space: `O(log(n))` retained memory plus garbage, see the note below. + /// where `n` denotes the number of elements stored in the set. + /// + /// Note: Full map iteration creates `O(n)` temporary objects that will be collected as garbage. + public func iter(rbSet : Set, direction : Direction) : I.Iter { + object { + var trees : IterRep = ?(#tr(rbSet), null); + public func next() : ?T { + switch (direction, trees) { + case (_, null) { null }; + case (_, ?(#tr(#leaf), ts)) { + trees := ts; + next() + }; + case (_, ?(#x(x), ts)) { + trees := ts; + ?x + }; // TODO: Let's float-out case on direction + case (#fwd, ?(#tr(#node(_, l, x, r)), ts)) { + trees := ?(#tr(l), ?(#x(x), ?(#tr(r), ts))); + next() + }; + case (#bwd, ?(#tr(#node(_, l, x, r)), ts)) { + trees := ?(#tr(r), ?(#x(x), ?(#tr(l), ts))); + next() + } + } + } + } + }; + /// Returns an Iterator (`Iter`) over the elements of the set. /// Iterator provides a single method `next()`, which returns /// elements in ascending order, or `null` when out of elements to iterate over. @@ -399,7 +473,7 @@ module { /// where `n` denotes the number of elements stored in the set. /// /// Note: Full set iteration creates `O(n)` temporary objects that will be collected as garbage. - public func elements(s : Set) : I.Iter = Map.keys(s); + public func elements(s : Set) : I.Iter = iter(s, #fwd); /// Create a new empty Set. /// @@ -417,7 +491,7 @@ module { /// Cost of empty set creation /// Runtime: `O(1)`. /// Space: `O(1)` - public func empty() : Set = Map.empty(); + public func empty() : Set = #leaf; /// Determine the size of the tree as the number of elements. /// @@ -430,7 +504,7 @@ module { /// let setOps = Set.SetOps(Nat.compare); /// let rbSet = setOps.fromIter(Iter.fromArray([0, 2, 1])); /// - /// Debug.print(debug_show(Map.size(rbSet))); + /// Debug.print(debug_show(Set.size(rbSet))); /// /// // 3 /// ``` @@ -438,7 +512,14 @@ module { /// Runtime: `O(n)`. /// Space: `O(1)`. /// where `n` denotes the number of elements stored in the tree. - public func size(rbSet : Set) : Nat = Map.size(rbSet); + public func size(t : Set) : Nat { + switch t { + case (#leaf) { 0 }; + case (#node(_, l, _, r)) { + size(l) + size(r) + 1 + } + } + }; /// Collapses the elements in `rbSet` into a single value by starting with `base` /// and progessively combining elements into `base` with `combine`. Iteration runs @@ -466,17 +547,17 @@ module { /// where `n` denotes the number of elements stored in the set. /// /// Note: Full set iteration creates `O(n)` temporary objects that will be collected as garbage. - public func foldLeft ( + public func foldLeft( rbSet : Set, base : Accum, combine : (T, Accum) -> Accum ) : Accum { - Map.foldLeft( - rbSet, - base, - func (x : T , _ : (), acc : Accum) : Accum { combine(x, acc) } - ) + var acc = base; + for(val in iter(rbSet, #fwd)){ + acc := combine(val, acc); + }; + acc }; /// Collapses the elements in `rbSet` into a single value by starting with `base` @@ -505,17 +586,17 @@ module { /// where `n` denotes the number of elements stored in the set. /// /// Note: Full set iteration creates `O(n)` temporary objects that will be collected as garbage. - public func foldRight ( + public func foldRight( rbSet : Set, base : Accum, combine : (T, Accum) -> Accum ) : Accum { - Map.foldRight( - rbSet, - base, - func (x : T , _ : (), acc : Accum) : Accum { combine(x, acc) } - ) + var acc = base; + for(val in iter(rbSet, #bwd)){ + acc := combine(val, acc); + }; + acc }; /// Test if set is empty. @@ -539,4 +620,349 @@ module { case _ { false }; }; }; + + module Internal { + public func contains(t : Set, compare : (T, T) -> O.Order, x : T) : Bool { + switch t { + case (#leaf) { false }; + case (#node(_c, l, x1, r)) { + switch (compare(x, x1)) { + case (#less) { contains(l, compare, x) }; + case (#equal) { true }; + case (#greater) { contains(r, compare, x) } + } + } + } + }; + + func redden(t : Set) : Set { + switch t { + case (#node (#B, l, x, r)) { + (#node (#R, l, x, r)) + }; + case _ { + Debug.trap "RBTree.red" + } + } + }; + + func lbalance(left : Set, x : T, right : Set) : Set { + switch (left, right) { + case (#node(#R, #node(#R, l1, x1, r1), x2, r2), r) { + #node( + #R, + #node(#B, l1, x1, r1), + x2, + #node(#B, r2, x, r)) + }; + case (#node(#R, l1, x1, #node(#R, l2, x2, r2)), r) { + #node( + #R, + #node(#B, l1, x1, l2), + x2, + #node(#B, r2, x, r)) + }; + case _ { + #node(#B, left, x, right) + } + } + }; + + func rbalance(left : Set, x : T, right : Set) : Set { + switch (left, right) { + case (l, #node(#R, l1, x1, #node(#R, l2, x2, r2))) { + #node( + #R, + #node(#B, l, x, l1), + x1, + #node(#B, l2, x2, r2)) + }; + case (l, #node(#R, #node(#R, l1, x1, r1), x2, r2)) { + #node( + #R, + #node(#B, l, x, l1), + x1, + #node(#B, r1, x2, r2)) + }; + case _ { + #node(#B, left, x, right) + }; + } + }; + + public func put ( + s : Set, + compare : (T, T) -> O.Order, + elem : T, + ) + : Set{ + func ins(tree : Set) : Set { + switch tree { + case (#leaf) { + #node(#R, #leaf, elem, #leaf) + }; + case (#node(#B, left, x, right)) { + switch (compare (elem, x)) { + case (#less) { + lbalance(ins left, x, right) + }; + case (#greater) { + rbalance(left, x, ins right) + }; + case (#equal) { + #node(#B, left, x, right) + } + } + }; + case (#node(#R, left, x, right)) { + switch (compare (elem, x)) { + case (#less) { + #node(#R, ins left, x, right) + }; + case (#greater) { + #node(#R, left, x, ins right) + }; + case (#equal) { + #node(#R, left, x, right) + } + } + } + }; + }; + switch (ins s) { + case (#node(#R, left, x, right)) { + #node(#B, left, x, right); + }; + case other { other }; + }; + }; + + func balLeft(left : Set, x : T, right : Set) : Set { + switch (left, right) { + case (#node(#R, l1, x1, r1), r) { + #node(#R, #node(#B, l1, x1, r1), x, r) + }; + case (_, #node(#B, l2, x2, r2)) { + rbalance(left, x, #node(#R, l2, x2, r2)) + }; + case (_, #node(#R, #node(#B, l2, x2, r2), x3, r3)) { + #node(#R, + #node(#B, left, x, l2), + x2, + rbalance(r2, x3, redden r3)) + }; + case _ { Debug.trap "balLeft" }; + } + }; + + func balRight(left : Set, x : T, right : Set) : Set { + switch (left, right) { + case (l, #node(#R, l1, x1, r1)) { + #node(#R, l, x, #node(#B, l1, x1, r1)) + }; + case (#node(#B, l1, x1, r1), r) { + lbalance(#node(#R, l1, x1, r1), x, r); + }; + case (#node(#R, l1, x1, #node(#B, l2, x2, r2)), r3) { + #node(#R, + lbalance(redden l1, x1, l2), + x2, + #node(#B, r2, x, r3)) + }; + case _ { Debug.trap "balRight" }; + } + }; + + func append(left : Set, right: Set) : Set { + switch (left, right) { + case (#leaf, _) { right }; + case (_, #leaf) { left }; + case (#node (#R, l1, x1, r1), + #node (#R, l2, x2, r2)) { + switch (append (r1, l2)) { + case (#node (#R, l3, x3, r3)) { + #node( + #R, + #node(#R, l1, x1, l3), + x3, + #node(#R, r3, x2, r2)) + }; + case r1l2 { + #node(#R, l1, x1, #node(#R, r1l2, x2, r2)) + } + } + }; + case (t1, #node(#R, l2, x2, r2)) { + #node(#R, append(t1, l2), x2, r2) + }; + case (#node(#R, l1, x1, r1), t2) { + #node(#R, l1, x1, append(r1, t2)) + }; + case (#node(#B, l1, x1, r1), #node (#B, l2, x2, r2)) { + switch (append (r1, l2)) { + case (#node (#R, l3, x3, r3)) { + #node(#R, + #node(#B, l1, x1, l3), + x3, + #node(#B, r3, x2, r2)) + }; + case r1l2 { + balLeft ( + l1, + x1, + #node(#B, r1l2, x2, r2) + ) + } + } + } + } + }; + + public func delete(tree : Set, compare : (T, T) -> O.Order, x : T) : Set { + func delNode(left : Set, x1 : T, right : Set) : Set { + switch (compare (x, x1)) { + case (#less) { + let newLeft = del left; + switch left { + case (#node(#B, _, _, _)) { + balLeft(newLeft, x1, right) + }; + case _ { + #node(#R, newLeft, x1, right) + } + } + }; + case (#greater) { + let newRight = del right; + switch right { + case (#node(#B, _, _, _)) { + balRight(left, x1, newRight) + }; + case _ { + #node(#R, left, x1, newRight) + } + } + }; + case (#equal) { + append(left, right) + }; + } + }; + func del(tree : Set) : Set { + switch tree { + case (#leaf) { + tree + }; + case (#node(_, left, x1, right)) { + delNode(left, x1, right) + } + }; + }; + switch (del(tree)) { + case (#node(#R, left, x1, right)) { + #node(#B, left, x1, right); + }; + case other { other }; + }; + }; + + // TODO: Instead, consider storing the black height in the node constructor + public func blackHeight (rbSet : Set) : Nat { + func f (node : Set, acc : Nat) : Nat { + switch node { + case (#leaf) { acc }; + case (#node (#R, l1, _, _)) { f(l1, acc) }; + case (#node (#B, l1, _, _)) { f(l1, acc + 1) } + } + }; + f (rbSet, 0) + }; + + public func joinL(l : Set, x : T, r : Set) : Set { + if (blackHeight r <= blackHeight l) { (#node (#R, l, x, r)) } + else { + switch r { + case (#node (#R, rl, rx, rr)) { (#node (#R, joinL(l, x, rl) , rx, rr)) }; + case (#node (#B, rl, rx, rr)) { balLeft (joinL(l, x, rl), rx, rr) }; + case _ { Debug.trap "joinL" }; + } + } + }; + + public func joinR(l : Set, x : T, r : Set) : Set { + if (blackHeight l <= blackHeight r) { (#node (#R, l, x, r)) } + else { + switch l { + case (#node (#R, ll, lx, lr)) { (#node (#R, ll , lx, joinR (lr, x, r))) }; + case (#node (#B, ll, lx, lr)) { balRight (ll, lx, joinR (lr, x, r)) }; + case _ { Debug.trap "joinR" }; + } + } + }; + + public func paint(color : Color, rbMap : Set) : Set { + switch rbMap { + case (#leaf) { #leaf }; + case (#node (_, l, x, r)) { (#node (color, l, x, r)) }; + } + }; + + public func splitMin (rbSet : Set) : (T, Set) { + switch rbSet { + case (#leaf) { Debug.trap "splitMin" }; + case (#node(_, #leaf, x, r)) { (x, r) }; + case (#node(_, l, x, r)) { + let (m, l2) = splitMin l; + (m, join(l2, x, r)) + }; + } + }; + + // Joins an element and two trees. + // See Tobias Nipkow's "Functional Data Structures and Algorithms", 117 + public func join(l : Set, x : T, r : Set) : Set { + if (Internal.blackHeight r < Internal.blackHeight l) { + return Internal.paint(#B, Internal.joinR(l, x, r)) + }; + if (Internal.blackHeight l < Internal.blackHeight r) { + return Internal.paint(#B, Internal.joinL(l, x, r)) + }; + return (#node (#B, l, x, r)) + }; + + // Joins two trees. + // See Tobias Nipkow's "Functional Data Structures and Algorithms", 117 + public func join2(l : Set, r : Set) : Set { + switch r { + case (#leaf) { l }; + case _ { + let (m, r2) = Internal.splitMin r; + join(l, m, r2) + }; + } + }; + + // Splits `rbSet` with respect to a given element `x`, into tuple `(l, b, r)` + // such that `l` contains the elements less than `x`, `r` contains the elements greater than `x` + // and `b` is `true` if `x` was in the `rbSet`. + // See Tobias Nipkow's "Functional Data Structures and Algorithms", 117 + public func split(x : T, rbSet : Set, compare : (T, T) -> O.Order) : (Set, Bool, Set) { + switch rbSet { + case (#leaf) { (#leaf, false, #leaf)}; + case (#node (_, l, x1, r)) { + switch (compare(x, x1)) { + case (#less) { + let (l1, b, l2) = split(x, l, compare); + (l1, b, join(l2, x1, r)) + }; + case (#equal) { (l, true, r) }; + case (#greater) { + let (r1, b, r2) = split(x, r, compare); + (join(l, x1, r1), b, r2) + }; + }; + }; + }; + }; + } } diff --git a/test/PersistentOrderedSet.test.mo b/test/PersistentOrderedSet.test.mo index 3b37b219..756dc2d2 100644 --- a/test/PersistentOrderedSet.test.mo +++ b/test/PersistentOrderedSet.test.mo @@ -32,9 +32,9 @@ func checkSet(rbSet : Set.Set) { func blackDepth(node : Set.Set) : Nat { switch node { case (#leaf) 0; - case (#node(color, left, (key, _), right)) { - checkKey(left, func(x) { x < key }); - checkKey(right, func(x) { x > key }); + case (#node(color, left, x1, right)) { + checkElem(left, func(x) { x < x1 }); + checkElem(right, func(x) { x > x1 }); let leftBlacks = blackDepth(left); let rightBlacks = blackDepth(right); assert (leftBlacks == rightBlacks); @@ -60,11 +60,11 @@ func isRed(node : Set.Set) : Bool { } }; -func checkKey(node : Set.Set, isValid : Nat -> Bool) { +func checkElem(node : Set.Set, isValid : Nat -> Bool) { switch node { case (#leaf) {}; - case (#node(_, _, (key, _), _)) { - assert (isValid(key)) + case (#node(_, _, elem, _)) { + assert (isValid(elem)) } } };