From 1961fab3ba2d1459ea6bae29c1c997328e03328f Mon Sep 17 00:00:00 2001 From: Pavel Golovin Date: Wed, 13 Nov 2024 17:36:49 +0100 Subject: [PATCH] [Milestone-3] Serokell: New OrderedSet.mo & OrderedMap fixup (#662) This is an MR for the 3rd Milestone of the Serokell's grant about improving Motoko's base library. The main goal of the PR is to introduce a new functional implementation of the set data structure to the' base' library. Also, it brings a few changes to the new functional map that was added in #664 , #654 . # General changes: * rename `PersistentOrderedMap` to `OrderedMap` (same for the `OrderedSet`) * improve docs # Functional Map changes: ## New functionality: + add `any`/`all` functions + add `contains` function + add `minEntry`/`maxEntry` ## Optimizations: + Store `size` in the Map, [benchmark results](https://github.com/serokell/motoko-base/pull/35) ## Fixup: + add `entriesRev()`, remove `iter()` # NEW functional Set: The new data structure implements an ordered set interface using Red-Black trees as well as the new functional map from the 1-2 Milestones. ## API implemented: * Basic operations (based on the map): `put`, `delete`, `contains`, `fromIter`, etc * Maps and folds: `map`, `mapFilter`, `foldLeft`, `foldRight` * Set operations: `union` , `intersect`, `diff`, `isSubset`, `equal` * Additional operations (as for the `OrderedMap`): `min`/`max`, `all`/`some` ## Maintainance support: * Unit, property tests * Documentation ## Applied optimizations: * Same optimizations that were useful for the functional map: * inline node color * float-out exceeded matching in iteration * `map`/`filterMap` through `foldLeft` * direct recursion in `foldLeft` * [Benchmark results for all four optimizations together](https://github.com/serokell/motoko-base/pull/27) * store size in the root of the tree, [benchmark results](https://github.com/serokell/motoko-base/pull/36#issuecomment-2455860708) * Pattern matching order optimization, [benchmark results](https://github.com/serokell/motoko-base/pull/36#issuecomment-2455998376) * Other optimizations: * Inline code of `OrderedMap` instead of sharing it, [benchmark results](https://github.com/serokell/motoko-base/pull/25) * `intersect` optimization: use order of output values to build the resulting tree faster, see https://github.com/serokell/motoko-base/pull/39 * `isSubset`, `equal` optimization: use early exit and use order of subtrees to reduce intermediate tree height, see https://github.com/serokell/motoko-base/pull/37 ## Rejected optimizations: * Nipkow's implementation of set operation [Tobias Nipkow's "Functional Data Structures and Algorithms", 117]. Initially, we were planning to use an implementation of set operations (`intersect`, `union`, `diff`) from Nipkow's book. However, the experiment shows that naive implementation with a simple size heuristic performs better. [The benchmark results](https://github.com/serokell/motoko-base/pull/33) are comparing 3 versions: * persistentset_baseline -- original implementation that uses Nipkow's algorithms. However, the black height is calculated before each set operation (the book assumes it's stored). * persistentset_bh -- the same as the baseline but the black height is stored in each node. * persistentset -- naive implementation that looks up in a smaller set and modifies a bigger one (it gives us `O(min(n,m)log((max(n,m))` which is very close to Nipkow's version). Sizes of sets are also stored but only in the root. The last one outperforms others and keeps a tree slim in terms of byte size. Thus, we have picked it. ## Final benchmark results: ### Collection benchmarks | |binary_size|generate|max mem|batch_get 50|batch_put 50|batch_remove 50|upgrade| |--:|--:|--:|--:|--:|--:|--:|--:| |orderedset+100|218_168|186_441|37_916|53_044|121_237|127_460|346_108| |trieset+100|211_245|574_022|47_652|131_218|288_429|268_499|729_696| |orderedset+1000|218_168|2_561_296|520_364|69_883|158_349|170_418|3_186_579| |trieset+1000|211_245|7_374_045|633_440|162_806|383_594|375_264|9_178_466| |orderedset+10000|218_168|40_015_301|320_532|84_660|192_931|215_592|31_522_120| |trieset+10000|211_245|105_695_670|682_792|192_931|457_923|462_594|129_453_045| |orderedset+100000|218_168|476_278_087|3_200_532|98_553|230_123|258_372|409_032_232| |trieset+100000|211_245|1_234_038_235|6_826_516|222_247|560_440|549_813|1_525_692_388| |orderedset+1000000|218_168|5_514_198_432|32_000_532|115_836|268_236|306_896|4_090_302_778| |trieset+1000000|211_245|13_990_048_548|68_228_312|252_211|650_405|642_099|17_455_845_492| ### set API | |size|intersect|union|diff|equals|isSubset| |--:|--:|--:|--:|--:|--:|--:| |orderedset+100|100|146_264|157_544|215_871|28_117|27_726| |trieset+100|100|352_496|411_306|350_935|201_896|201_456| |orderedset+1000|1000|162_428|194_198|286_747|242_329|241_938| |trieset+1000|1000|731_650|1_079_906|912_629|2_589_090|4_023_673| |orderedset+10000|10000|177_080|231_070|345_529|2_383_587|2_383_591| |trieset+10000|10000|3_986_854|21_412_306|5_984_106|46_174_710|31_885_381| |orderedset+100000|100000|190_727|267_008|402_081|91_300_348|91_300_393| |trieset+100000|100000|178_863_894|209_889_623|199_028_396|521_399_350|521_399_346| |orderedset+1000000|1000000|205_022|304_937|464_859|912_901_595|912_901_558| |trieset+1000000|1000000|1_782_977_198|2_092_850_787|1_984_818_266|5_813_335_155|5_813_335_151| ### new set API | |size|foldLeft|foldRight|mapfilter|map| |--:|--:|--:|--:|--:|--:| |orderedset|100|16_487|16_463|88_028|224_597| |orderedset|1000|133_685|131_953|1_526_510|4_035_782| |orderedset|10000|1_305_120|1_287_495|28_455_361|51_527_733| |orderedset|100000|13_041_665|12_849_418|344_132_505|630_692_463| |orderedset|1000000|130_428_573|803_454_777|4_019_592_041|7_453_944_902| --------- Co-authored-by: Andrei Borzenkov Co-authored-by: Andrei Borzenkov Co-authored-by: Sergey Gulin Co-authored-by: Claudio Russo --- src/OrderedMap.mo | 1225 +++++++++++++++++++++++++++++++++ src/OrderedSet.mo | 1226 ++++++++++++++++++++++++++++++++++ test/OrderedMap.prop.test.mo | 277 ++++++++ test/OrderedMap.test.mo | 574 ++++++++++++++++ test/OrderedSet.prop.test.mo | 334 +++++++++ test/OrderedSet.test.mo | 600 +++++++++++++++++ 6 files changed, 4236 insertions(+) create mode 100644 src/OrderedMap.mo create mode 100644 src/OrderedSet.mo create mode 100644 test/OrderedMap.prop.test.mo create mode 100644 test/OrderedMap.test.mo create mode 100644 test/OrderedSet.prop.test.mo create mode 100644 test/OrderedSet.test.mo diff --git a/src/OrderedMap.mo b/src/OrderedMap.mo new file mode 100644 index 00000000..091c44eb --- /dev/null +++ b/src/OrderedMap.mo @@ -0,0 +1,1225 @@ +/// Stable key-value map implemented as a red-black tree with nodes storing key-value pairs. +/// +/// A red-black tree is a balanced binary search tree ordered by the keys. +/// +/// The tree data structure internally colors each of its nodes either red or black, +/// and uses this information to balance the tree during the modifying operations. +/// +/// Performance: +/// * Runtime: `O(log(n))` worst case cost per insertion, removal, and retrieval operation. +/// * Space: `O(n)` for storing the entire tree. +/// `n` denotes the number of key-value entries (i.e. nodes) stored in the tree. +/// +/// Note: +/// * Map operations, such as retrieval, insertion, and removal create `O(log(n))` temporary objects that become garbage. +/// +/// Credits: +/// +/// The core of this implementation is derived from: +/// +/// * Ken Friis Larsen's [RedBlackMap.sml](https://github.com/kfl/mosml/blob/master/src/mosmllib/Redblackmap.sml), which itself is based on: +/// * Stefan Kahrs, "Red-black trees with types", Journal of Functional Programming, 11(4): 425-432 (2001), [version 1 in web appendix](http://www.cs.ukc.ac.uk/people/staff/smk/redblack/rb.html). + +import Debug "Debug"; +import I "Iter"; +import List "List"; +import Nat "Nat"; +import O "Order"; + +module { + /// Collection of key-value entries, ordered by the keys and key unique. + /// The keys have the generic type `K` and the values the generic type `V`. + /// If `K` and `V` is stable types then `Map` is also stable. + /// To ensure that property the `Map` does not have any methods, instead + /// they are gathered in the functor-like class `Operations` (see example there). + public type Map = { + size : Nat; + root : Tree + }; + + // Note: Leaves are considered implicitly black. + type Tree = { + #red : (Tree, K, V, Tree); + #black : (Tree, K, V, Tree); + #leaf + }; + + /// Class that captures key type `K` along with its ordering function `compare` + /// and provides all operations to work with a map of type `Map`. + /// + /// An instance object should be created once as a canister field to ensure + /// that the same ordering function is used for every operation. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// + /// actor { + /// let natMap = Map.Make(Nat.compare); // : Operations + /// stable var keyStorage : Map.Map = natMap.empty(); + /// + /// public func addKey(id : Nat, key : Text) : async () { + /// keyStorage := natMap.put(keyStorage, id, key); + /// } + /// } + /// ``` + public class Operations(compare : (K, K) -> O.Order) { + + /// Returns a new map, containing all entries given by the iterator `i`. + /// If there are multiple entries with the same key the last one is taken. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let m = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.entries(m)))); + /// + /// // [(0, "Zero"), (1, "One"), (2, "Two")] + /// ``` + /// + /// Runtime: `O(n * log(n))`. + /// Space: `O(n)` retained memory plus garbage, see the note below. + /// where `n` denotes the number of key-value entries stored in the map and + /// assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: Creates `O(n * log(n))` temporary objects that will be collected as garbage. + public func fromIter(i : I.Iter<(K, V)>) : Map + = Internal.fromIter(i, compare); + + /// Insert the value `value` with key `key` into the map `m`. Overwrites any existing entry with key `key`. + /// Returns a modified map. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// var map = natMap.empty(); + /// + /// map := natMap.put(map, 0, "Zero"); + /// map := natMap.put(map, 2, "Two"); + /// map := natMap.put(map, 1, "One"); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.entries(map)))); + /// + /// // [(0, "Zero"), (1, "One"), (2, "Two")] + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(log(n))`. + /// where `n` denotes the number of key-value entries stored in the map and + /// assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: The returned map shares with the `m` most of the tree nodes. + /// Garbage collecting one of maps (e.g. after an assignment `m := natMap.put(m, k)`) + /// causes collecting `O(log(n))` nodes. + public func put(m : Map, key : K, value : V) : Map + = replace(m, key, value).0; + + /// Insert the value `value` with key `key` into the map `m`. Returns modified map and + /// the previous value associated with key `key` or `null` if no such value exists. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map0 = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// let (map1, old1) = natMap.replace(map0, 0, "Nil"); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.entries(map1)))); + /// Debug.print(debug_show(old1)); + /// // [(0, "Nil"), (1, "One"), (2, "Two")] + /// // ?"Zero" + /// + /// let (map2, old2) = natMap.replace(map0, 3, "Three"); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.entries(map2)))); + /// Debug.print(debug_show(old2)); + /// // [(0, "Zero"), (1, "One"), (2, "Two"), (3, "Three")] + /// // null + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(log(n))` retained memory plus garbage, see the note below. + /// where `n` denotes the number of key-value entries stored in the map and + /// assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: The returned map shares with the `m` most of the tree nodes. + /// Garbage collecting one of maps (e.g. after an assignment `m := natMap.replace(m, k).0`) + /// causes collecting `O(log(n))` nodes. + public func replace(m : Map, key : K, value : V) : (Map, ?V) { + switch (Internal.replace(m.root, compare, key, value)) { + case (t, null) { ({root = t; size = m.size + 1}, null) }; + case (t, v) { ({root = t; size = m.size}, v)} + } + }; + + /// Creates a new map by applying `f` to each entry in the map `m`. For each entry + /// `(k, v)` in the old map, if `f` evaluates to `null`, the entry is discarded. + /// Otherwise, the entry is transformed into a new entry `(k, v2)`, where + /// the new value `v2` is the result of applying `f` to `(k, v)`. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// func f(key : Nat, val : Text) : ?Text { + /// if(key == 0) {null} + /// else { ?("Twenty " # val)} + /// }; + /// + /// let newMap = natMap.mapFilter(map, f); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.entries(newMap)))); + /// + /// // [(1, "Twenty One"), (2, "Twenty Two")] + /// ``` + /// + /// Runtime: `O(n * log(n))`. + /// Space: `O(n)` retained memory plus garbage, see the note below. + /// where `n` denotes the number of key-value entries stored in the map and + /// assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: Creates `O(n * log(n))` temporary objects that will be collected as garbage. + public func mapFilter(m : Map, f : (K, V1) -> ?V2) : Map + = Internal.mapFilter(m, compare, f); + + /// Get the value associated with key `key` in the given map `m` if present and `null` otherwise. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// Debug.print(debug_show(natMap.get(map, 1))); + /// Debug.print(debug_show(natMap.get(map, 42))); + /// + /// // ?"One" + /// // null + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(1)`. + /// where `n` denotes the number of key-value entries stored in the map and + /// assuming that the `compare` function implements an `O(1)` comparison. + public func get(m : Map, key : K) : ?V + = Internal.get(m.root, compare, key); + + /// Test whether the map `m` contains any binding for the given `key`. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// Debug.print(debug_show natMap.contains(map, 1)); // => true + /// Debug.print(debug_show natMap.contains(map, 42)); // => false + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(1)`. + /// where `n` denotes the number of key-value entries stored in the map and + /// assuming that the `compare` function implements an `O(1)` comparison. + public func contains(m: Map, key: K) : Bool + = Internal.contains(m.root, compare, key); + + /// Retrieves a key-value pair from the map `m` with a maximal key. If the map is empty returns `null`. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// Debug.print(debug_show(natMap.maxEntry(map))); // => ?(2, "Two") + /// Debug.print(debug_show(natMap.maxEntry(natMap.empty()))); // => null + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(1)`. + /// where `n` denotes the number of key-value entries stored in the map. + public func maxEntry(m: Map) : ?(K, V) + = Internal.maxEntry(m.root); + + /// Retrieves a key-value pair from the map `m` with a minimal key. If the map is empty returns `null`. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Iter "mo:base/Iter"; + /// import Nat "mo:base/Nat"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// Debug.print(debug_show(natMap.minEntry(map))); // => ?(0, "Zero") + /// Debug.print(debug_show(natMap.minEntry(natMap.empty()))); // => null + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(1)`. + /// where `n` denotes the number of key-value entries stored in the map. + public func minEntry(m : Map) : ?(K, V) + = Internal.minEntry(m.root); + + /// Deletes the entry with the key `key` from the map `m`. Has no effect if `key` is not + /// present in the map. Returns modified map. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.entries(natMap.delete(map, 1))))); + /// Debug.print(debug_show(Iter.toArray(natMap.entries(natMap.delete(map, 42))))); + /// + /// // [(0, "Zero"), (2, "Two")] + /// // [(0, "Zero"), (1, "One"), (2, "Two")] + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(log(n))` + /// where `n` denotes the number of key-value entries stored in the map and + /// assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: The returned map shares with the `m` most of the tree nodes. + /// Garbage collecting one of maps (e.g. after an assignment `m := natMap.delete(m, k).0`) + /// causes collecting `O(log(n))` nodes. + public func delete(m : Map, key : K) : Map + = remove(m, key).0; + + /// Deletes the entry with the key `key`. Returns modified map and the + /// previous value associated with key `key` or `null` if no such value exists. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map0 = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// let (map1, old1) = natMap.remove(map0, 0); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.entries(map1)))); + /// Debug.print(debug_show(old1)); + /// // [(1, "One"), (2, "Two")] + /// // ?"Zero" + /// + /// let (map2, old2) = natMap.remove(map0, 42); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.entries(map2)))); + /// Debug.print(debug_show(old2)); + /// // [(0, "Zero"), (1, "One"), (2, "Two")] + /// // null + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(log(n))`. + /// where `n` denotes the number of key-value entries stored in the map and + /// assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: The returned map shares with the `m` most of the tree nodes. + /// Garbage collecting one of maps (e.g. after an assignment `m := natMap.remove(m, k)`) + /// causes collecting `O(log(n))` nodes. + public func remove(m : Map, key : K) : (Map, ?V) { + switch (Internal.remove(m.root, compare, key)) { + case (t, null) { ({root = t; size = m.size }, null) }; + case (t, v) { ({root = t; size = m.size - 1}, v) } + } + }; + + /// Create a new empty map. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// + /// let map = natMap.empty(); + /// + /// Debug.print(debug_show(natMap.size(map))); + /// + /// // 0 + /// ``` + /// + /// Cost of empty map creation + /// Runtime: `O(1)`. + /// Space: `O(1)` + public func empty() : Map + = Internal.empty(); + + /// Returns an Iterator (`Iter`) over the key-value pairs in the map. + /// Iterator provides a single method `next()`, which returns + /// pairs in ascending order by keys, or `null` when out of pairs to iterate over. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.entries(map)))); + /// // [(0, "Zero"), (1, "One"), (2, "Two")] + /// var sum = 0; + /// for ((k, _) in natMap.entries(map)) { sum += k; }; + /// Debug.print(debug_show(sum)); // => 3 + /// ``` + /// Cost of iteration over all elements: + /// Runtime: `O(n)`. + /// Space: `O(log(n))` retained memory plus garbage, see the note below. + /// where `n` denotes the number of key-value entries stored in the map. + /// + /// Note: Full map iteration creates `O(n)` temporary objects that will be collected as garbage. + public func entries(m : Map) : I.Iter<(K, V)> + = Internal.iter(m.root, #fwd); + + /// Same as `entries` but iterates in the descending order. + public func entriesRev(m : Map) : I.Iter<(K, V)> + = Internal.iter(m.root, #bwd); + + /// Returns an Iterator (`Iter`) over the keys of the map. + /// Iterator provides a single method `next()`, which returns + /// keys in ascending order, or `null` when out of keys to iterate over. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.keys(map)))); + /// + /// // [0, 1, 2] + /// ``` + /// Cost of iteration over all elements: + /// Runtime: `O(n)`. + /// Space: `O(log(n))` retained memory plus garbage, see the note below. + /// where `n` denotes the number of key-value entries stored in the map. + /// + /// Note: Full map iteration creates `O(n)` temporary objects that will be collected as garbage. + public func keys(m : Map) : I.Iter + = I.map(entries(m), func(kv : (K, V)) : K {kv.0}); + + + /// Returns an Iterator (`Iter`) over the values of the map. + /// Iterator provides a single method `next()`, which returns + /// values in ascending order of associated keys, or `null` when out of values to iterate over. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.vals(map)))); + /// + /// // ["Zero", "One", "Two"] + /// ``` + /// Cost of iteration over all elements: + /// Runtime: `O(n)`. + /// Space: `O(log(n))` retained memory plus garbage, see the note below. + /// where `n` denotes the number of key-value entries stored in the map. + /// + /// Note: Full map iteration creates `O(n)` temporary objects that will be collected as garbage. + public func vals(m : Map) : I.Iter + = I.map(entries(m), func(kv : (K, V)) : V {kv.1}); + + /// Creates a new map by applying `f` to each entry in the map `m`. Each entry + /// `(k, v)` in the old map is transformed into a new entry `(k, v2)`, where + /// the new value `v2` is created by applying `f` to `(k, v)`. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// func f(key : Nat, _val : Text) : Nat = key * 2; + /// + /// let resMap = natMap.map(map, f); + /// + /// Debug.print(debug_show(Iter.toArray(natMap.entries(resMap)))); + /// // [(0, 0), (1, 2), (2, 4)] + /// ``` + /// + /// Cost of mapping all the elements: + /// Runtime: `O(n)`. + /// Space: `O(n)` retained memory + /// where `n` denotes the number of key-value entries stored in the map. + public func map(m : Map, f : (K, V1) -> V2) : Map + = Internal.map(m, f); + + /// Determine the size of the map as the number of key-value entries. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// Debug.print(debug_show(natMap.size(map))); + /// // 3 + /// ``` + /// + /// Runtime: `O(n)`. + /// Space: `O(1)`. + public func size(m : Map) : Nat + = m.size; + + /// Collapses the elements in the `map` into a single value by starting with `base` + /// and progressively combining keys and values into `base` with `combine`. Iteration runs + /// left to right. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// func folder(accum : (Nat, Text), key : Nat, val : Text) : ((Nat, Text)) + /// = (key + accum.0, accum.1 # val); + /// + /// Debug.print(debug_show(natMap.foldLeft(map, (0, ""), folder))); + /// + /// // (3, "ZeroOneTwo") + /// ``` + /// + /// Cost of iteration over all elements: + /// Runtime: `O(n)`. + /// Space: depends on `combine` function plus garbage, see the note below. + /// where `n` denotes the number of key-value entries stored in the map. + /// + /// Note: Full map iteration creates `O(n)` temporary objects that will be collected as garbage. + public func foldLeft( + map : Map, + base : Accum, + combine : (Accum, K, Value) -> Accum + ) : Accum + = Internal.foldLeft(map.root, base, combine); + + /// Collapses the elements in the `map` into a single value by starting with `base` + /// and progressively combining keys and values into `base` with `combine`. Iteration runs + /// right to left. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "Zero"), (2, "Two"), (1, "One")])); + /// + /// func folder(key : Nat, val : Text, accum : (Nat, Text)) : ((Nat, Text)) + /// = (key + accum.0, accum.1 # val); + /// + /// Debug.print(debug_show(natMap.foldRight(map, (0, ""), folder))); + /// + /// // (3, "TwoOneZero") + /// ``` + /// + /// Cost of iteration over all elements: + /// Runtime: `O(n)`. + /// Space: depends on `combine` function plus garbage, see the note below. + /// where `n` denotes the number of key-value entries stored in the map. + /// + /// Note: Full map iteration creates `O(n)` temporary objects that will be collected as garbage. + public func foldRight( + map : Map, + base : Accum, + combine : (K, Value, Accum) -> Accum + ) : Accum + = Internal.foldRight(map.root, base, combine); + + /// Test whether all key-value pairs satisfy a given predicate `pred`. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "0"), (2, "2"), (1, "1")])); + /// + /// Debug.print(debug_show(natMap.all(map, func (k, v) = (v == debug_show(k))))); + /// // true + /// Debug.print(debug_show(natMap.all(map, func (k, v) = (k < 2)))); + /// // false + /// ``` + /// + /// Runtime: `O(n)`. + /// Space: `O(1)`. + /// where `n` denotes the number of key-value entries stored in the map. + public func all(m : Map, pred : (K, V) -> Bool) : Bool + = Internal.all(m.root, pred); + + /// Test if there exists a key-value pair satisfying a given predicate `pred`. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natMap = Map.Make(Nat.compare); + /// let map = natMap.fromIter(Iter.fromArray([(0, "0"), (2, "2"), (1, "1")])); + /// + /// Debug.print(debug_show(natMap.some(map, func (k, v) = (k >= 3)))); + /// // false + /// Debug.print(debug_show(natMap.some(map, func (k, v) = (k >= 0)))); + /// // true + /// ``` + /// + /// Runtime: `O(n)`. + /// Space: `O(1)`. + /// where `n` denotes the number of key-value entries stored in the map. + public func some(m : Map, pred : (K, V) -> Bool) : Bool + = Internal.some(m.root, pred); + + /// Debug helper that check internal invariants of the given map `m`. + /// Raise an error (for a stack trace) if invariants are violated. + public func validate(m : Map) : () { + Internal.validate(m, compare); + }; + }; + + module Internal { + + public func empty() : Map { + { size = 0; root = #leaf } + }; + + public func fromIter(i : I.Iter<(K,V)>, compare : (K, K) -> O.Order) : Map { + var map = #leaf : Tree; + var size = 0; + for(val in i) { + map := put(map, compare, val.0, val.1); + size += 1; + }; + {root = map; size} + }; + + type IterRep = List.List<{ #tr : Tree; #xy : (K, V) }>; + + public func iter(map : Tree, direction : { #fwd; #bwd }) : I.Iter<(K, V)> { + let turnLeftFirst : MapTraverser = func(l, x, y, r, ts) { + ?(#tr(l), ?(#xy(x, y), ?(#tr(r), ts))) + }; + + let turnRightFirst : MapTraverser = func(l, x, y, r, ts) { + ?(#tr(r), ?(#xy(x, y), ?(#tr(l), ts))) + }; + + switch direction { + case (#fwd) IterMap(map, turnLeftFirst); + case (#bwd) IterMap(map, turnRightFirst) + } + }; + + type MapTraverser = (Tree, K, V, Tree, IterRep) -> IterRep; + + class IterMap(tree : Tree, mapTraverser : MapTraverser) { + var trees : IterRep = ?(#tr(tree), null); + public func next() : ?(K, V) { + switch (trees) { + case (null) { null }; + case (?(#tr(#leaf), ts)) { + trees := ts; + next() + }; + case (?(#xy(xy), ts)) { + trees := ts; + ?xy + }; + case (?(#tr(#red(l, x, y, r)), ts)) { + trees := mapTraverser(l, x, y, r, ts); + next() + }; + case (?(#tr(#black(l, x, y, r)), ts)) { + trees := mapTraverser(l, x, y, r, ts); + next() + } + } + } + }; + + public func map(map : Map, f : (K, V1) -> V2) : Map { + func mapRec(m : Tree) : Tree { + switch m { + case (#leaf) { #leaf }; + case (#red(l, x, y, r)) { + #red(mapRec l, x, f(x, y), mapRec r) + }; + case (#black(l, x, y, r)) { + #black(mapRec l, x, f(x, y), mapRec r) + } + } + }; + { size = map.size; root = mapRec(map.root) } + }; + + public func foldLeft( + map : Tree, + base : Accum, + combine : (Accum, Key, Value) -> Accum + ) : Accum { + switch (map) { + case (#leaf) { base }; + case (#red(l, k, v, r)) { + let left = foldLeft(l, base, combine); + let middle = combine(left, k, v); + foldLeft(r, middle, combine) + }; + case (#black(l, k, v, r)) { + let left = foldLeft(l, base, combine); + let middle = combine(left, k, v); + foldLeft(r, middle, combine) + } + } + }; + + public func foldRight( + map : Tree, + base : Accum, + combine : (Key, Value, Accum) -> Accum + ) : Accum { + switch (map) { + case (#leaf) { base }; + case (#red(l, k, v, r)) { + let right = foldRight(r, base, combine); + let middle = combine(k, v, right); + foldRight(l, middle, combine) + }; + case (#black(l, k, v, r)) { + let right = foldRight(r, base, combine); + let middle = combine(k, v, right); + foldRight(l, middle, combine) + } + } + }; + + public func mapFilter(map : Map, compare : (K, K) -> O.Order, f : (K, V1) -> ?V2) : Map { + var size = 0; + func combine(acc : Tree, key : K, value1 : V1) : Tree { + switch (f(key, value1)) { + case null { acc }; + case (?value2) { + size += 1; + put(acc, compare, key, value2) + } + } + }; + { root = foldLeft(map.root, #leaf, combine); size } + }; + + public func get(t : Tree, compare : (K, K) -> O.Order, x : K) : ?V { + switch t { + case (#red(l, x1, y1, r)) { + switch (compare(x, x1)) { + case (#less) { get(l, compare, x) }; + case (#equal) { ?y1 }; + case (#greater) { get(r, compare, x) } + } + }; + case (#black(l, x1, y1, r)) { + switch (compare(x, x1)) { + case (#less) { get(l, compare, x) }; + case (#equal) { ?y1 }; + case (#greater) { get(r, compare, x) } + } + }; + case (#leaf) { null } + } + }; + + public func contains(m : Tree, compare : (K, K) -> O.Order, key : K) : Bool { + switch (get(m, compare, key)) { + case(null) { false }; + case(_) { true } + } + }; + + public func maxEntry(m : Tree) : ?(K, V) { + func rightmost(m : Tree) : (K, V) { + switch m { + case (#red(_, k, v, #leaf)) { (k, v) }; + case (#red(_, _, _, r)) { rightmost(r) }; + case (#black(_, k, v, #leaf)) { (k, v) }; + case (#black(_, _, _, r)) { rightmost(r) }; + case (#leaf) { Debug.trap "OrderedMap.impossible" } + } + }; + switch m { + case (#leaf) { null }; + case (_) { ?rightmost(m) } + } + }; + + public func minEntry(m : Tree) : ?(K, V) { + func leftmost(m : Tree) : (K, V) { + switch m { + case (#red(#leaf, k, v, _)) { (k, v) }; + case (#red(l, _, _, _)) { leftmost(l) }; + case (#black(#leaf, k, v, _)) { (k, v) }; + case (#black(l, _, _, _)) { leftmost(l)}; + case (#leaf) { Debug.trap "OrderedMap.impossible" } + } + }; + switch m { + case (#leaf) { null }; + case (_) { ?leftmost(m) } + } + }; + + public func all(m : Tree, pred : (K, V) -> Bool) : Bool { + switch m { + case (#red(l, k, v, r)) { + pred(k, v) and all(l, pred) and all(r, pred) + }; + case (#black(l, k, v, r)) { + pred(k, v) and all(l, pred) and all(r, pred) + }; + case (#leaf) { true } + } + }; + + public func some(m : Tree, pred : (K, V) -> Bool) : Bool { + switch m { + case (#red(l, k, v, r)) { + pred(k, v) or some(l, pred) or some(r, pred) + }; + case (#black(l, k, v, r)) { + pred(k, v) or some(l, pred) or some(r, pred) + }; + case (#leaf) { false } + } + }; + + func redden(t : Tree) : Tree { + switch t { + case (#black (l, x, y, r)) { + (#red (l, x, y, r)) + }; + case _ { + Debug.trap "OrderedMap.red" + } + } + }; + + func lbalance(left : Tree, x : K, y : V, right : Tree) : Tree { + switch (left, right) { + case (#red(#red(l1, x1, y1, r1), x2, y2, r2), r) { + #red( + #black(l1, x1, y1, r1), + x2, + y2, + #black(r2, x, y, r) + ) + }; + case (#red(l1, x1, y1, #red(l2, x2, y2, r2)), r) { + #red( + #black(l1, x1, y1, l2), + x2, + y2, + #black(r2, x, y, r) + ) + }; + case _ { + #black(left, x, y, right) + } + } + }; + + func rbalance(left : Tree, x : K, y : V, right : Tree) : Tree { + switch (left, right) { + case (l, #red(l1, x1, y1, #red(l2, x2, y2, r2))) { + #red( + #black(l, x, y, l1), + x1, + y1, + #black(l2, x2, y2, r2) + ) + }; + case (l, #red(#red(l1, x1, y1, r1), x2, y2, r2)) { + #red( + #black(l, x, y, l1), + x1, + y1, + #black(r1, x2, y2, r2) + ) + }; + case _ { + #black(left, x, y, right) + } + } + }; + + type ClashResolver = { old : A; new : A } -> A; + + func insertWith( + m : Tree, + compare : (K, K) -> O.Order, + key : K, + val : V, + onClash : ClashResolver + ) : Tree { + func ins(tree : Tree) : Tree { + switch tree { + case (#black(left, x, y, right)) { + switch (compare(key, x)) { + case (#less) { + lbalance(ins left, x, y, right) + }; + case (#greater) { + rbalance(left, x, y, ins right) + }; + case (#equal) { + let newVal = onClash({ new = val; old = y }); + #black(left, key, newVal, right) + } + } + }; + case (#red(left, x, y, right)) { + switch (compare(key, x)) { + case (#less) { + #red(ins left, x, y, right) + }; + case (#greater) { + #red(left, x, y, ins right) + }; + case (#equal) { + let newVal = onClash { new = val; old = y }; + #red(left, key, newVal, right) + } + } + }; + case (#leaf) { + #red(#leaf, key, val, #leaf) + } + } + }; + switch (ins m) { + case (#red(left, x, y, right)) { + #black(left, x, y, right) + }; + case other { other } + } + }; + + public func replace( + m : Tree, + compare : (K, K) -> O.Order, + key : K, + val : V + ) : (Tree, ?V) { + var oldVal : ?V = null; + func onClash(clash : { old : V; new : V }) : V { + oldVal := ?clash.old; + clash.new + }; + let res = insertWith(m, compare, key, val, onClash); + (res, oldVal) + }; + + public func put( + m : Tree, + compare : (K, K) -> O.Order, + key : K, + val : V + ) : Tree = replace(m, compare, key, val).0; + + func balLeft(left : Tree, x : K, y : V, right : Tree) : Tree { + switch (left, right) { + case (#red(l1, x1, y1, r1), r) { + #red( + #black(l1, x1, y1, r1), + x, + y, + r + ) + }; + case (_, #black(l2, x2, y2, r2)) { + rbalance(left, x, y, #red(l2, x2, y2, r2)) + }; + case (_, #red(#black(l2, x2, y2, r2), x3, y3, r3)) { + #red( + #black(left, x, y, l2), + x2, + y2, + rbalance(r2, x3, y3, redden r3) + ) + }; + case _ { Debug.trap "balLeft" } + } + }; + + func balRight(left : Tree, x : K, y : V, right : Tree) : Tree { + switch (left, right) { + case (l, #red(l1, x1, y1, r1)) { + #red( + l, + x, + y, + #black(l1, x1, y1, r1) + ) + }; + case (#black(l1, x1, y1, r1), r) { + lbalance(#red(l1, x1, y1, r1), x, y, r) + }; + case (#red(l1, x1, y1, #black(l2, x2, y2, r2)), r3) { + #red( + lbalance(redden l1, x1, y1, l2), + x2, + y2, + #black(r2, x, y, r3) + ) + }; + case _ { Debug.trap "balRight" } + } + }; + + func append(left : Tree, right : Tree) : Tree { + switch (left, right) { + case (#leaf, _) { right }; + case (_, #leaf) { left }; + case ( + #red(l1, x1, y1, r1), + #red(l2, x2, y2, r2) + ) { + switch (append(r1, l2)) { + case (#red(l3, x3, y3, r3)) { + #red( + #red(l1, x1, y1, l3), + x3, + y3, + #red(r3, x2, y2, r2) + ) + }; + case r1l2 { + #red(l1, x1, y1, #red(r1l2, x2, y2, r2)) + } + } + }; + case (t1, #red(l2, x2, y2, r2)) { + #red(append(t1, l2), x2, y2, r2) + }; + case (#red(l1, x1, y1, r1), t2) { + #red(l1, x1, y1, append(r1, t2)) + }; + case (#black(l1, x1, y1, r1), #black(l2, x2, y2, r2)) { + switch (append(r1, l2)) { + case (#red(l3, x3, y3, r3)) { + #red( + #black(l1, x1, y1, l3), + x3, + y3, + #black(r3, x2, y2, r2) + ) + }; + case r1l2 { + balLeft( + l1, + x1, + y1, + #black(r1l2, x2, y2, r2) + ) + } + } + } + } + }; + + public func delete(m : Tree, compare : (K, K) -> O.Order, key : K) : Tree + = remove(m, compare, key).0; + + public func remove(tree : Tree, compare : (K, K) -> O.Order, x : K) : (Tree, ?V) { + var y0 : ?V = null; + func delNode(left : Tree, x1 : K, y1 : V, right : Tree) : Tree { + switch (compare(x, x1)) { + case (#less) { + let newLeft = del left; + switch left { + case (#black(_, _, _, _)) { + balLeft(newLeft, x1, y1, right) + }; + case _ { + #red(newLeft, x1, y1, right) + } + } + }; + case (#greater) { + let newRight = del right; + switch right { + case (#black(_, _, _, _)) { + balRight(left, x1, y1, newRight) + }; + case _ { + #red(left, x1, y1, newRight) + } + } + }; + case (#equal) { + y0 := ?y1; + append(left, right) + } + } + }; + func del(tree : Tree) : Tree { + switch tree { + case (#red(left, x, y, right)) { + delNode(left, x, y, right) + }; + case (#black(left, x, y, right)) { + delNode(left, x, y, right) + }; + case (#leaf) { + tree + } + } + }; + switch (del(tree)) { + case (#red(left, x, y, right)) { + (#black(left, x, y, right), y0) + }; + case other { (other, y0) } + } + }; + + // Test helper + public func validate(rbMap : Map, comp : (K, K) -> O.Order) { + ignore blackDepth(rbMap.root, comp) + }; + + func blackDepth(node : Tree, comp : (K, K) -> O.Order) : Nat { + func checkNode(left : Tree, key : K, right : Tree) : Nat { + checkKey(left, func(x : K) : Bool { comp(x, key) == #less }); + checkKey(right, func(x : K) : Bool { comp(x, key) == #greater }); + let leftBlacks = blackDepth(left, comp); + let rightBlacks = blackDepth(right, comp); + assert (leftBlacks == rightBlacks); + leftBlacks + }; + switch node { + case (#leaf) 0; + case (#red(left, key, _, right)) { + let leftBlacks = checkNode(left, key, right); + assert (not isRed(left)); + assert (not isRed(right)); + leftBlacks + }; + case (#black(left, key, _, right)) { + checkNode(left, key, right) + 1 + } + } + }; + + func isRed(node : Tree) : Bool { + switch node { + case (#red(_, _, _, _)) true; + case _ false + } + }; + + func checkKey(node : Tree, isValid : K -> Bool) { + switch node { + case (#leaf) {}; + case (#red(_, key, _, _)) { + assert (isValid(key)) + }; + case (#black(_, key, _, _)) { + assert (isValid(key)) + } + } + }; + }; + + /// Create `OrderedMap.Operations` object capturing key type `K` and `compare` function. + /// It is an alias for the `Operations` constructor. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/OrderedMap"; + /// import Nat "mo:base/Nat"; + /// + /// actor { + /// let natMap = Map.Make(Nat.compare); + /// stable var map : Map.Map = natMap.empty(); + /// }; + /// ``` + public let Make : (compare : (K, K) -> O.Order) -> Operations = Operations +} diff --git a/src/OrderedSet.mo b/src/OrderedSet.mo new file mode 100644 index 00000000..651441a7 --- /dev/null +++ b/src/OrderedSet.mo @@ -0,0 +1,1226 @@ +/// Stable ordered set implemented as a red-black tree. +/// +/// A red-black tree is a balanced binary search tree ordered by the elements. +/// +/// The tree data structure internally colors each of its nodes either red or black, +/// and uses this information to balance the tree during the modifying operations. +/// +/// Performance: +/// * Runtime: `O(log(n))` worst case cost per insertion, removal, and retrieval operation. +/// * Space: `O(n)` for storing the entire tree. +/// `n` denotes the number of elements (i.e. nodes) stored in the tree. +/// +/// Credits: +/// +/// The core of this implementation is derived from: +/// +/// * Ken Friis Larsen's [RedBlackMap.sml](https://github.com/kfl/mosml/blob/master/src/mosmllib/Redblackmap.sml), which itself is based on: +/// * Stefan Kahrs, "Red-black trees with types", Journal of Functional Programming, 11(4): 425-432 (2001), [version 1 in web appendix](http://www.cs.ukc.ac.uk/people/staff/smk/redblack/rb.html). + +import Debug "Debug"; +import Buffer "Buffer"; +import I "Iter"; +import List "List"; +import Nat "Nat"; +import O "Order"; + +module { + /// Red-black tree of nodes with ordered set elements. + /// Leaves are considered implicitly black. + type Tree = { + #red : (Tree, T, Tree); + #black : (Tree, T, Tree); + #leaf + }; + + /// Ordered collection of unique elements of the generic type `T`. + /// If type `T` is stable then `Set` is also stable. + /// To ensure that property the `Set` does not have any methods, + /// instead they are gathered in the functor-like class `Operations` (see example there). + public type Set = { size : Nat; root : Tree }; + + /// Class that captures element type `T` along with its ordering function `compare` + /// and provide all operations to work with a set of type `Set`. + /// + /// An instance object should be created once as a canister field to ensure + /// that the same ordering function is used for every operation. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// + /// actor { + /// let natSet = Set.Make(Nat.compare); // : Operations + /// stable var usedIds : Set.Set = natSet.empty(); + /// + /// public func createId(id : Nat) : async () { + /// usedIds := natSet.put(usedIds, id); + /// }; + /// + /// public func idIsUsed(id: Nat) : async Bool { + /// natSet.contains(usedIds, id) + /// } + /// } + /// ``` + public class Operations(compare : (T, T) -> O.Order) { + + /// Returns a new Set, containing all entries given by the iterator `i`. + /// If there are multiple identical entries only one is taken. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// Debug.print(debug_show(Iter.toArray(natSet.vals(set)))); + /// // [0, 1, 2] + /// ``` + /// + /// Runtime: `O(n * log(n))`. + /// Space: `O(n)` retained memory plus garbage, see the note below. + /// where `n` denotes the number of elements stored in the set and + /// assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: Creates `O(n * log(n))` temporary objects that will be collected as garbage. + public func fromIter(i : I.Iter) : Set { + var set = empty() : Set; + for (val in i) { + set := Internal.put(set, compare, val) + }; + set + }; + + /// Insert the value `value` into the set `s`. Has no effect if `value` is already + /// present in the set. Returns a modified set. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// var set = natSet.empty(); + /// + /// set := natSet.put(set, 0); + /// set := natSet.put(set, 2); + /// set := natSet.put(set, 1); + /// + /// Debug.print(debug_show(Iter.toArray(natSet.vals(set)))); + /// // [0, 1, 2] + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(log(n))`. + /// where `n` denotes the number of elements stored in the set and + /// assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: The returned set shares with the `s` most of the tree nodes. + /// Garbage collecting one of sets (e.g. after an assignment `m := natSet.delete(m, k)`) + /// causes collecting `O(log(n))` nodes. + public func put(s : Set, value : T) : Set + = Internal.put(s, compare, value); + + /// Deletes the value `value` from the set `s`. Has no effect if `value` is not + /// present in the set. Returns modified set. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// Debug.print(debug_show(Iter.toArray(natSet.vals(natSet.delete(set, 1))))); + /// Debug.print(debug_show(Iter.toArray(natSet.vals(natSet.delete(set, 42))))); + /// // [0, 2] + /// // [0, 1, 2] + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(log(n))`. + /// where `n` denotes the number of elements stored in the set and + /// assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: The returned set shares with the `s` most of the tree nodes. + /// Garbage collecting one of sets (e.g. after an assignment `m := natSet.delete(m, k)`) + /// causes collecting `O(log(n))` nodes. + public func delete(s : Set, value : T) : Set + = Internal.delete(s, compare, value); + + /// Test if the set 's' contains a given element. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// Debug.print(debug_show natSet.contains(set, 1)); // => true + /// Debug.print(debug_show natSet.contains(set, 42)); // => false + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(1)` retained memory plus garbage, see the note below. + /// where `n` denotes the number of elements stored in the set and + /// assuming that the `compare` function implements an `O(1)` comparison. + public func contains(s : Set, value : T) : Bool + = Internal.contains(s.root, compare, value); + + /// Get a maximal element of the set `s` if it is not empty, otherwise returns `null` + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let s1 = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// let s2 = natSet.empty(); + /// + /// Debug.print(debug_show(natSet.max(s1))); // => ?2 + /// Debug.print(debug_show(natSet.max(s2))); // => null + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(1)`. + /// where `n` denotes the number of elements in the set + public func max(s : Set) : ?T + = Internal.max(s.root); + + /// Get a minimal element of the set `s` if it is not empty, otherwise returns `null` + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let s1 = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// let s2 = natSet.empty(); + /// + /// Debug.print(debug_show(natSet.min(s1))); // => ?0 + /// Debug.print(debug_show(natSet.min(s2))); // => null + /// ``` + /// + /// Runtime: `O(log(n))`. + /// Space: `O(1)`. + /// where `n` denotes the number of elements in the set + public func min(s : Set) : ?T + = Internal.min(s.root); + + /// [Set union](https://en.wikipedia.org/wiki/Union_(set_theory)) operation. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set1 = natSet.fromIter(Iter.fromArray([0, 1, 2])); + /// let set2 = natSet.fromIter(Iter.fromArray([2, 3, 4])); + /// + /// Debug.print(debug_show Iter.toArray(natSet.vals(natSet.union(set1, set2)))); + /// // [0, 1, 2, 3, 4] + /// ``` + /// + /// Runtime: `O(m * log(n))`. + /// Space: `O(m)`, retained memory plus garbage, see the note below. + /// where `m` and `n` denote the number of elements in the sets, and `m <= n`. + /// + /// Note: Creates `O(m * log(n))` temporary objects that will be collected as garbage. + public func union(s1 : Set, s2 : Set) : Set { + if (size(s1) < size(s2)) { + foldLeft(s1, s2, func(acc : Set, elem : T) : Set { Internal.put(acc, compare, elem) }) + } else { + foldLeft(s2, s1, func(acc : Set, elem : T) : Set { Internal.put(acc, compare, elem) }) + } + }; + + /// [Set intersection](https://en.wikipedia.org/wiki/Intersection_(set_theory)) operation. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set1 = natSet.fromIter(Iter.fromArray([0, 1, 2])); + /// let set2 = natSet.fromIter(Iter.fromArray([1, 2, 3])); + /// + /// Debug.print(debug_show Iter.toArray(natSet.vals(natSet.intersect(set1, set2)))); + /// // [1, 2] + /// ``` + /// + /// Runtime: `O(m * log(n))`. + /// Space: `O(m)`, retained memory plus garbage, see the note below. + /// where `m` and `n` denote the number of elements in the sets, and `m <= n`. + /// + /// Note: Creates `O(m)` temporary objects that will be collected as garbage. + public func intersect(s1 : Set, s2 : Set) : Set { + let elems = Buffer.Buffer(Nat.min(Nat.min(s1.size, s2.size), 100)); + if (s1.size < s2.size) { + Internal.iterate(s1.root, func (x: T) { + if (Internal.contains(s2.root, compare, x)) { + elems.add(x) + } + }); + } else { + Internal.iterate(s2.root, func (x: T) { + if (Internal.contains(s1.root, compare, x)) { + elems.add(x) + } + }); + }; + { root = Internal.buildFromSorted(elems); size = elems.size() } + }; + + /// [Set difference](https://en.wikipedia.org/wiki/Difference_(set_theory)). + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set1 = natSet.fromIter(Iter.fromArray([0, 1, 2])); + /// let set2 = natSet.fromIter(Iter.fromArray([1, 2, 3])); + /// + /// Debug.print(debug_show Iter.toArray(natSet.vals(natSet.diff(set1, set2)))); + /// // [0] + /// ``` + /// + /// Runtime: `O(m * log(n))`. + /// Space: `O(m)`, retained memory plus garbage, see the note below. + /// where `m` and `n` denote the number of elements in the sets, and `m <= n`. + /// + /// Note: Creates `O(m * log(n))` temporary objects that will be collected as garbage. + public func diff(s1 : Set, s2 : Set) : Set { + if (size(s1) < size(s2)) { + let elems = Buffer.Buffer(Nat.min(s1.size, 100)); + Internal.iterate(s1.root, func (x : T) { + if (not Internal.contains(s2.root, compare, x)) { + elems.add(x) + } + } + ); + { root = Internal.buildFromSorted(elems); size = elems.size() } + } + else { + foldLeft(s2, s1, + func (acc : Set, elem : T) : Set { + if (Internal.contains(acc.root, compare, elem)) { Internal.delete(acc, compare, elem) } else { acc } + } + ) + } + }; + + /// Creates a new `Set` by applying `f` to each entry in the set `s`. Each element + /// `x` in the old set is transformed into a new entry `x2`, where + /// the new value `x2` is created by applying `f` to `x`. + /// The result set may be smaller than the original set due to duplicate elements. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 1, 2, 3])); + /// + /// func f(x : Nat) : Nat = if (x < 2) { x } else { 0 }; + /// + /// let resSet = natSet.map(set, f); + /// + /// Debug.print(debug_show(Iter.toArray(natSet.vals(resSet)))); + /// // [0, 1] + /// ``` + /// + /// Cost of mapping all the elements: + /// Runtime: `O(n * log(n))`. + /// Space: `O(n)` retained memory + /// where `n` denotes the number of elements stored in the set. + /// + /// Note: Creates `O(n * log(n))` temporary objects that will be collected as garbage. + public func map(s : Set, f : T1 -> T) : Set + = Internal.foldLeft(s.root, empty(), func (acc : Set, elem : T1) : Set { Internal.put(acc, compare, f(elem)) }); + + /// Creates a new set by applying `f` to each element in the set `s`. For each element + /// `x` in the old set, if `f` evaluates to `null`, the element is discarded. + /// Otherwise, the entry is transformed into a new entry `x2`, where + /// the new value `x2` is the result of applying `f` to `x`. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 1, 2, 3])); + /// + /// func f(x : Nat) : ?Nat { + /// if(x == 0) {null} + /// else { ?( x * 2 )} + /// }; + /// + /// let newRbSet = natSet.mapFilter(set, f); + /// + /// Debug.print(debug_show(Iter.toArray(natSet.vals(newRbSet)))); + /// // [2, 4, 6] + /// ``` + /// + /// Runtime: `O(n * log(n))`. + /// Space: `O(n)` retained memory plus garbage, see the note below. + /// where `n` denotes the number of elements stored in the set and + /// assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: Creates `O(n * log(n))` temporary objects that will be collected as garbage. + public func mapFilter(s : Set, f : T1 -> ?T) : Set { + func combine(acc : Set, elem : T1) : Set { + switch (f(elem)) { + case null { acc }; + case (?elem2) { + Internal.put(acc, compare, elem2) + } + } + }; + Internal.foldLeft(s.root, empty(), combine) + }; + + /// Test if `set1` is subset of `set2`. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set1 = natSet.fromIter(Iter.fromArray([1, 2])); + /// let set2 = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// Debug.print(debug_show natSet.isSubset(set1, set2)); // => true + /// ``` + /// + /// Runtime: `O(m * log(n))`. + /// Space: `O(1)` retained memory plus garbage, see the note below. + /// where `m` and `n` denote the number of elements stored in the sets set1 and set2, respectively, + /// and assuming that the `compare` function implements an `O(1)` comparison. + public func isSubset(s1 : Set, s2 : Set) : Bool { + if (s1.size > s2.size) { return false }; + isSubsetHelper(s1.root, s2.root) + }; + + /// Test if two sets are equal. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set1 = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// let set2 = natSet.fromIter(Iter.fromArray([1, 2])); + /// + /// Debug.print(debug_show natSet.equals(set1, set1)); // => true + /// Debug.print(debug_show natSet.equals(set1, set2)); // => false + /// ``` + /// + /// Runtime: `O(m * log(n))`. + /// Space: `O(1)` retained memory plus garbage, see the note below. + /// where `m` and `n` denote the number of elements stored in the sets set1 and set2, respectively, + /// and assuming that the `compare` function implements an `O(1)` comparison. + public func equals(s1 : Set, s2 : Set) : Bool { + if (s1.size != s2.size) { return false }; + isSubsetHelper(s1.root, s2.root) + }; + + func isSubsetHelper(t1 : Tree, t2 : Tree) : Bool { + switch (t1, t2) { + case (#leaf, _) { true }; + case (_, #leaf) { false }; + case ((#red(t1l, x1, t1r) or #black(t1l, x1, t1r)), (#red(t2l, x2, t2r)) or #black(t2l, x2, t2r)) { + switch (compare(x1, x2)) { + case (#equal) { isSubsetHelper(t1l, t2l) and isSubsetHelper(t1r, t2r) }; + // x1 < x2 ==> x1 \in t2l /\ t1l \subset t2l + case (#less) { Internal.contains(t2l, compare, x1) and isSubsetHelper(t1l, t2l) and isSubsetHelper(t1r, t2) }; + // x2 < x1 ==> x1 \in t2r /\ t1r \subset t2r + case (#greater) { Internal.contains(t2r, compare, x1) and isSubsetHelper(t1l, t2) and isSubsetHelper(t1r, t2r) } + } + } + } + }; + + /// Returns an Iterator (`Iter`) over the elements of the set. + /// Iterator provides a single method `next()`, which returns + /// elements in ascending order, or `null` when out of elements to iterate over. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// Debug.print(debug_show(Iter.toArray(natSet.vals(set)))); + /// // [0, 1, 2] + /// ``` + /// Cost of iteration over all elements: + /// Runtime: `O(n)`. + /// Space: `O(log(n))` retained memory plus garbage, see the note below. + /// where `n` denotes the number of elements stored in the set. + /// + /// Note: Full set iteration creates `O(n)` temporary objects that will be collected as garbage. + public func vals(s : Set) : I.Iter + = Internal.iter(s.root, #fwd); + + /// Same as `vals()` but iterates over elements of the set `s` in the descending order. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// Debug.print(debug_show(Iter.toArray(natSet.valsRev(set)))); + /// // [2, 1, 0] + /// ``` + /// Cost of iteration over all elements: + /// Runtime: `O(n)`. + /// Space: `O(log(n))` retained memory plus garbage, see the note below. + /// where `n` denotes the number of elements stored in the set. + /// + /// Note: Full set iteration creates `O(n)` temporary objects that will be collected as garbage. + public func valsRev(s : Set) : I.Iter + = Internal.iter(s.root, #bwd); + + /// Create a new empty Set. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.empty(); + /// + /// Debug.print(debug_show(natSet.size(set))); // => 0 + /// ``` + /// + /// Cost of empty set creation + /// Runtime: `O(1)`. + /// Space: `O(1)` + public func empty() : Set + = { root = #leaf; size = 0}; + + /// Returns the number of elements in the set. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// Debug.print(debug_show(natSet.size(set))); // => 3 + /// ``` + /// + /// Runtime: `O(1)`. + /// Space: `O(1)`. + public func size(s : Set) : Nat + = s.size; + + /// Collapses the elements in `set` into a single value by starting with `base` + /// and progessively combining elements into `base` with `combine`. Iteration runs + /// left to right. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// func folder(accum : Nat, val : Nat) : Nat = val + accum; + /// + /// Debug.print(debug_show(natSet.foldLeft(set, 0, folder))); + /// // 3 + /// ``` + /// + /// Cost of iteration over all elements: + /// Runtime: `O(n)`. + /// Space: depends on `combine` function plus garbage, see the note below. + /// where `n` denotes the number of elements stored in the set. + /// + /// Note: Full set iteration creates `O(n)` temporary objects that will be collected as garbage. + public func foldLeft( + set : Set, + base : Accum, + combine : (Accum, T) -> Accum + ) : Accum + = Internal.foldLeft(set.root, base, combine); + + /// Collapses the elements in `set` into a single value by starting with `base` + /// and progessively combining elements into `base` with `combine`. Iteration runs + /// right to left. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// func folder(val : Nat, accum : Nat) : Nat = val + accum; + /// + /// Debug.print(debug_show(natSet.foldRight(set, 0, folder))); + /// // 3 + /// ``` + /// + /// Cost of iteration over all elements: + /// Runtime: `O(n)`. + /// Space: depends on `combine` function plus garbage, see the note below. + /// where `n` denotes the number of elements stored in the set. + /// + /// Note: Full set iteration creates `O(n)` temporary objects that will be collected as garbage. + public func foldRight( + set : Set, + base : Accum, + combine : (T, Accum) -> Accum + ) : Accum + = Internal.foldRight(set.root, base, combine); + + /// Test if the given set `s` is empty. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.empty(); + /// + /// Debug.print(debug_show(natSet.isEmpty(set))); // => true + /// ``` + /// + /// Runtime: `O(1)`. + /// Space: `O(1)`. + public func isEmpty(s : Set) : Bool { + switch (s.root) { + case (#leaf) { true }; + case _ { false } + } + }; + + /// Test whether all values in the set `s` satisfy a given predicate `pred`. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// Debug.print(debug_show(natSet.all(set, func (v) = (v < 10)))); + /// // true + /// Debug.print(debug_show(natSet.all(set, func (v) = (v < 2)))); + /// // false + /// ``` + /// + /// Runtime: `O(n)`. + /// Space: `O(1)`. + /// where `n` denotes the number of elements stored in the set. + public func all(s : Set, pred : T -> Bool) : Bool + = Internal.all(s.root, pred); + + /// Test if there exists an element in the set `s` satisfying the given predicate `pred`. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// import Debug "mo:base/Debug"; + /// + /// let natSet = Set.Make(Nat.compare); + /// let set = natSet.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// Debug.print(debug_show(natSet.some(set, func (v) = (v >= 3)))); + /// // false + /// Debug.print(debug_show(natSet.some(set, func (v) = (v >= 0)))); + /// // true + /// ``` + /// + /// Runtime: `O(n)`. + /// Space: `O(1)`. + /// where `n` denotes the number of elements stored in the set. + public func some(s : Set, pred : (T) -> Bool) : Bool + = Internal.some(s.root, pred); + + /// Test helper that check internal invariant for the given set `s`. + /// Raise an error (for a stack trace) if invariants are violated. + public func validate(s : Set): () { + Internal.validate(s, compare); + } + }; + + module Internal { + public func contains(tree : Tree, compare : (T, T) -> O.Order, elem : T) : Bool { + func f(t : Tree, x : T) : Bool { + switch t { + case (#black(l, x1, r)) { + switch (compare(x, x1)) { + case (#less) { f(l, x) }; + case (#equal) { true }; + case (#greater) { f(r, x) } + } + }; + case (#red(l, x1, r)) { + switch (compare(x, x1)) { + case (#less) { f(l, x) }; + case (#equal) { true }; + case (#greater) { f(r, x) } + } + }; + case (#leaf) { false } + } + }; + f(tree, elem) + }; + + public func max(m : Tree) : ?V { + func rightmost(m : Tree) : V { + switch m { + case (#red(_, v, #leaf)) { v }; + case (#red(_, _, r)) { rightmost(r) }; + case (#black(_, v, #leaf)) { v }; + case (#black(_, _, r)) { rightmost(r) }; + case (#leaf) { Debug.trap "OrderedSet.impossible" } + } + }; + switch m { + case (#leaf) { null }; + case (_) { ?rightmost(m) } + } + }; + + public func min(m : Tree) : ?V { + func leftmost(m : Tree) : V { + switch m { + case (#red(#leaf, v, _)) { v }; + case (#red(l, _, _)) { leftmost(l) }; + case (#black(#leaf, v, _)) { v }; + case (#black(l, _, _)) { leftmost(l)}; + case (#leaf) { Debug.trap "OrderedSet.impossible" } + } + }; + switch m { + case (#leaf) { null }; + case (_) { ?leftmost(m) } + } + }; + + public func all(m : Tree, pred : V -> Bool) : Bool { + switch m { + case (#red(l, v, r)) { + pred(v) and all(l, pred) and all(r, pred) + }; + case (#black(l, v, r)) { + pred(v) and all(l, pred) and all(r, pred) + }; + case (#leaf) { true } + } + }; + + public func some(m : Tree, pred : V -> Bool) : Bool { + switch m { + case (#red(l, v, r)) { + pred(v) or some(l, pred) or some(r, pred) + }; + case (#black(l, v, r)) { + pred(v) or some(l, pred) or some(r, pred) + }; + case (#leaf) { false } + } + }; + + public func iterate(m : Tree, f : V -> ()) { + switch m { + case (#leaf) { }; + case (#black(l, v, r)) { iterate(l, f); f(v); iterate(r, f) }; + case (#red(l, v, r)) { iterate(l, f); f(v); iterate(r, f) } + } + }; + + // build tree from elements arr[l]..arr[r-1] + public func buildFromSorted(buf : Buffer.Buffer) : Tree { + var maxDepth = 0; + var maxSize = 1; + while (maxSize < buf.size()) { + maxDepth += 1; + maxSize += maxSize + 1; + }; + maxDepth := if (maxDepth == 0) {1} else {maxDepth}; // keep root black for 1 element tree + func buildFromSortedHelper(l : Nat, r : Nat, depth : Nat) : Tree { + if (l + 1 == r) { + if (depth == maxDepth) { + return #red(#leaf, buf.get(l), #leaf); + } else { + return #black(#leaf, buf.get(l), #leaf); + } + }; + if (l >= r) { + return #leaf; + }; + let m = (l + r) / 2; + return #black( + buildFromSortedHelper(l, m, depth+1), + buf.get(m), + buildFromSortedHelper(m+1, r, depth+1) + ) + }; + buildFromSortedHelper(0, buf.size(), 0); + }; + + type IterRep = List.List<{ #tr : Tree; #x : T }>; + + type SetTraverser = (Tree, T, Tree, IterRep) -> IterRep; + + class IterSet(tree : Tree, setTraverser : SetTraverser) { + var trees : IterRep = ?(#tr(tree), null); + public func next() : ?T { + switch (trees) { + case (null) { null }; + case (?(#tr(#leaf), ts)) { + trees := ts; + next() + }; + case (?(#x(x), ts)) { + trees := ts; + ?x + }; + case (?(#tr(#black(l, x, r)), ts)) { + trees := setTraverser(l, x, r, ts); + next() + }; + case (?(#tr(#red(l, x, r)), ts)) { + trees := setTraverser(l, x, r, ts); + next() + } + } + } + }; + + public func iter(s : Tree, direction : {#fwd; #bwd}) : I.Iter { + let turnLeftFirst : SetTraverser + = func (l, x, r, ts) { ?(#tr(l), ?(#x(x), ?(#tr(r), ts))) }; + + let turnRightFirst : SetTraverser + = func (l, x, r, ts) { ?(#tr(r), ?(#x(x), ?(#tr(l), ts))) }; + + switch direction { + case (#fwd) IterSet(s, turnLeftFirst); + case (#bwd) IterSet(s, turnRightFirst) + } + }; + + public func foldLeft( + tree : Tree, + base : Accum, + combine : (Accum, T) -> Accum + ) : Accum { + switch (tree) { + case (#leaf) { base }; + case (#black(l, x, r)) { + let left = foldLeft(l, base, combine); + let middle = combine(left, x); + foldLeft(r, middle, combine) + }; + case (#red(l, x, r)) { + let left = foldLeft(l, base, combine); + let middle = combine(left, x); + foldLeft(r, middle, combine) + } + } + }; + + public func foldRight( + tree : Tree, + base : Accum, + combine : (T, Accum) -> Accum + ) : Accum { + switch (tree) { + case (#leaf) { base }; + case (#black(l, x, r)) { + let right = foldRight(r, base, combine); + let middle = combine(x, right); + foldRight(l, middle, combine) + }; + case (#red(l, x, r)) { + let right = foldRight(r, base, combine); + let middle = combine(x, right); + foldRight(l, middle, combine) + } + } + }; + + func redden(t : Tree) : Tree { + switch t { + case (#black(l, x, r)) { + (#red (l, x, r)) + }; + case _ { + Debug.trap "OrderedSet.red" + } + } + }; + + func lbalance(left : Tree, x : T, right : Tree) : Tree { + switch (left, right) { + case (#red(#red(l1, x1, r1), x2, r2), r) { + #red( + #black(l1, x1, r1), + x2, + #black(r2, x, r) + ) + }; + case (#red(l1, x1, #red(l2, x2, r2)), r) { + #red( + #black(l1, x1, l2), + x2, + #black(r2, x, r) + ) + }; + case _ { + #black(left, x, right) + } + } + }; + + func rbalance(left : Tree, x : T, right : Tree) : Tree { + switch (left, right) { + case (l, #red(l1, x1, #red(l2, x2, r2))) { + #red( + #black(l, x, l1), + x1, + #black(l2, x2, r2) + ) + }; + case (l, #red(#red(l1, x1, r1), x2, r2)) { + #red( + #black(l, x, l1), + x1, + #black(r1, x2, r2) + ) + }; + case _ { + #black(left, x, right) + } + } + }; + + public func put( + s : Set, + compare : (T, T) -> O.Order, + elem : T + ) : Set { + var newNodeIsCreated : Bool = false; + func ins(tree : Tree) : Tree { + switch tree { + case (#black(left, x, right)) { + switch (compare(elem, x)) { + case (#less) { + lbalance(ins left, x, right) + }; + case (#greater) { + rbalance(left, x, ins right) + }; + case (#equal) { + #black(left, x, right) + } + } + }; + case (#red(left, x, right)) { + switch (compare(elem, x)) { + case (#less) { + #red(ins left, x, right) + }; + case (#greater) { + #red(left, x, ins right) + }; + case (#equal) { + #red(left, x, right) + } + } + }; + case (#leaf) { + newNodeIsCreated := true; + #red(#leaf, elem, #leaf) + } + } + }; + let newRoot = switch (ins(s.root)) { + case (#red(left, x, right)) { + #black(left, x, right) + }; + case other { other } + }; + { root = newRoot; + size = if newNodeIsCreated { s.size + 1 } else { s.size } } + }; + + func balLeft(left : Tree, x : T, right : Tree) : Tree { + switch (left, right) { + case (#red(l1, x1, r1), r) { + #red(#black(l1, x1, r1), x, r) + }; + case (_, #black(l2, x2, r2)) { + rbalance(left, x, #red(l2, x2, r2)) + }; + case (_, #red(#black(l2, x2, r2), x3, r3)) { + #red( + #black(left, x, l2), + x2, + rbalance(r2, x3, redden r3) + ) + }; + case _ { Debug.trap "balLeft" } + } + }; + + func balRight(left : Tree, x : T, right : Tree) : Tree { + switch (left, right) { + case (l, #red(l1, x1, r1)) { + #red(l, x, #black(l1, x1, r1)) + }; + case (#black(l1, x1, r1), r) { + lbalance(#red(l1, x1, r1), x, r) + }; + case (#red(l1, x1, #black(l2, x2, r2)), r3) { + #red( + lbalance(redden l1, x1, l2), + x2, + #black(r2, x, r3) + ) + }; + case _ { Debug.trap "balRight" } + } + }; + + func append(left : Tree, right : Tree) : Tree { + switch (left, right) { + case (#leaf, _) { right }; + case (_, #leaf) { left }; + case ( + #red(l1, x1, r1), + #red(l2, x2, r2) + ) { + switch (append(r1, l2)) { + case (#red(l3, x3, r3)) { + #red( + #red(l1, x1, l3), + x3, + #red(r3, x2, r2) + ) + }; + case r1l2 { + #red(l1, x1, #red(r1l2, x2, r2)) + } + } + }; + case (t1, #red(l2, x2, r2)) { + #red(append(t1, l2), x2, r2) + }; + case (#red(l1, x1, r1), t2) { + #red(l1, x1, append(r1, t2)) + }; + case (#black(l1, x1, r1), #black(l2, x2, r2)) { + switch (append(r1, l2)) { + case (#red(l3, x3, r3)) { + #red( + #black(l1, x1, l3), + x3, + #black(r3, x2, r2) + ) + }; + case r1l2 { + balLeft( + l1, + x1, + #black(r1l2, x2, r2) + ) + } + } + } + } + }; + + public func delete(s : Set, compare : (T, T) -> O.Order, x : T) : Set { + var changed : Bool = false; + func delNode(left : Tree, x1 : T, right : Tree) : Tree { + switch (compare(x, x1)) { + case (#less) { + let newLeft = del left; + switch left { + case (#black(_, _, _)) { + balLeft(newLeft, x1, right) + }; + case _ { + #red(newLeft, x1, right) + } + } + }; + case (#greater) { + let newRight = del right; + switch right { + case (#black(_, _, _)) { + balRight(left, x1, newRight) + }; + case _ { + #red(left, x1, newRight) + } + } + }; + case (#equal) { + changed := true; + append(left, right) + } + } + }; + func del(tree : Tree) : Tree { + switch tree { + case (#black(left, x1, right)) { + delNode(left, x1, right) + }; + case (#red(left, x1, right)) { + delNode(left, x1, right) + }; + case (#leaf) { + tree + } + } + }; + let newRoot = switch (del(s.root)) { + case (#red(left, x1, right)) { + #black(left, x1, right) + }; + case other { other } + }; + { root = newRoot; + size = if changed { s.size -1 } else { s.size } } + }; + + // check binary search tree order of elements and black depth invariant of the RB-tree + public func validate(s : Set, comp : (T, T) -> O.Order) { + ignore blackDepth(s.root, comp) + }; + + func blackDepth(node : Tree, comp : (T, T) -> O.Order) : Nat { + func checkNode(left : Tree, x1 : T, right : Tree) : Nat { + checkElem(left, func(x: T) : Bool { comp(x, x1) == #less }); + checkElem(right, func(x: T) : Bool { comp(x, x1) == #greater }); + let leftBlacks = blackDepth(left, comp); + let rightBlacks = blackDepth(right, comp); + assert (leftBlacks == rightBlacks); + leftBlacks + }; + switch node { + case (#leaf) 0; + case (#red(left, x1, right)) { + assert (not isRed(left)); + assert (not isRed(right)); + checkNode(left, x1, right) + }; + case (#black(left, x1, right)) { + checkNode(left, x1, right) + 1 + } + } + }; + + func isRed(node : Tree) : Bool { + switch node { + case (#red(_, _, _)) true; + case _ false + } + }; + + func checkElem(node : Tree, isValid : T -> Bool) { + switch node { + case (#leaf) {}; + case (#black(_, elem, _)) { + assert (isValid(elem)) + }; + case (#red(_, elem, _)) { + assert (isValid(elem)) + } + } + } + }; + + /// Create `OrderedSet.Operations` object capturing element type `T` and `compare` function. + /// It is an alias for the `Operations` constructor. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/OrderedSet"; + /// import Nat "mo:base/Nat"; + /// + /// actor { + /// let natSet = Set.Make(Nat.compare); + /// stable var set : Set.Set = natSet.empty(); + /// }; + /// ``` + public let Make : (compare : (T, T) -> O.Order) -> Operations = Operations +} diff --git a/test/OrderedMap.prop.test.mo b/test/OrderedMap.prop.test.mo new file mode 100644 index 00000000..f3c698eb --- /dev/null +++ b/test/OrderedMap.prop.test.mo @@ -0,0 +1,277 @@ +// @testmode wasi + +import Map "../src/OrderedMap"; +import Nat "../src/Nat"; +import Iter "../src/Iter"; +import Debug "../src/Debug"; +import Array "../src/Array"; +import Option "../src/Option"; + +import Suite "mo:matchers/Suite"; +import T "mo:matchers/Testable"; +import M "mo:matchers/Matchers"; + +import Random2 "mo:base/Random"; + +let { run; test; suite } = Suite; + +let entryTestable = T.tuple2Testable(T.natTestable, T.textTestable); + +class MapMatcher(expected : Map.Map) : M.Matcher> { + public func describeMismatch(actual : Map.Map, _description : M.Description) { + Debug.print(debug_show (Iter.toArray(natMap.entries(actual))) # " should be " # debug_show (Iter.toArray(natMap.entries(expected)))) + }; + + public func matches(actual : Map.Map) : Bool { + Iter.toArray(natMap.entries(actual)) == Iter.toArray(natMap.entries(expected)) + } +}; + +object Random { + var number = 4711; + public func next() : Nat { + number := (15485863 * number + 5) % 15485867; + number + }; + + public func nextNat(range: (Nat, Nat)): Nat { + let n = next(); + let v = n % (range.1 - range.0 + 1) + range.0; + v + }; + + public func nextEntries(range: (Nat, Nat), size: Nat): [(Nat, Text)] { + Array.tabulate<(Nat, Text)>(size, func(_ix) { + let key = nextNat(range); (key, debug_show(key)) } ) + } +}; + +let natMap = Map.Make(Nat.compare); + +func mapGen(samples_number: Nat, size: Nat, range: (Nat, Nat)): Iter.Iter> { + object { + var n = 0; + public func next(): ?Map.Map { + n += 1; + if (n > samples_number) { + null + } else { + ?natMap.fromIter(Random.nextEntries(range, size).vals()) + } + } + } +}; + + +func run_all_props(range: (Nat, Nat), size: Nat, map_samples: Nat, query_samples: Nat) { + func prop(name: Text, f: Map.Map -> Bool): Suite.Suite { + var error_msg: Text = ""; + test(name, do { + var error = true; + label stop for(map in mapGen(map_samples, size, range)) { + if (not f(map)) { + error_msg := "Property \"" # name # "\" failed\n"; + error_msg #= "\n m: " # debug_show(Iter.toArray(natMap.entries(map))); + break stop; + } + }; + error_msg + }, M.describedAs(error_msg, M.equals(T.text("")))) + }; + func prop_with_key(name: Text, f: (Map.Map, Nat) -> Bool): Suite.Suite { + var error_msg: Text = ""; + test(name, do { + label stop for(map in mapGen(map_samples, size, range)) { + for (_query_ix in Iter.range(0, query_samples-1)) { + let key = Random.nextNat(range); + if (not f(map, key)) { + error_msg #= "Property \"" # name # "\" failed"; + error_msg #= "\n m: " # debug_show(Iter.toArray(natMap.entries(map))); + error_msg #= "\n k: " # debug_show(key); + break stop; + } + } + }; + error_msg + }, M.describedAs(error_msg, M.equals(T.text("")))) + }; + run( + suite("Property tests", + [ + suite("empty", [ + test("get(empty(), k) == null", label res : Bool { + for (_query_ix in Iter.range(0, query_samples-1)) { + let k = Random.nextNat(range); + if(natMap.get(natMap.empty(), k) != null) + break res(false); + }; + true; + }, M.equals(T.bool(true))) + ]), + + suite("get & put", [ + prop_with_key("get(put(m, k, v), k) == ?v", func (m, k) { + natMap.get(natMap.put(m, k, "v"), k) == ?"v" + }), + prop_with_key("get(put(put(m, k, v1), k, v2), k) == ?v2", func (m, k) { + let (v1, v2) = ("V1", "V2"); + natMap.get(natMap.put(natMap.put(m, k, v1), k, v2), k) == v2 + }), + ]), + + suite("replace", [ + prop_with_key("replace(m, k, v).0 == put(m, k, v)", func (m, k) { + natMap.replace(m, k, "v").0 == natMap.put(m, k, "v") + }), + prop_with_key("replace(put(m, k, v1), k, v2).1 == ?v1", func (m, k) { + natMap.replace(natMap.put(m, k, "v1"), k, "v2").1 == ?"v1" + }), + prop_with_key("get(m, k) == null ==> replace(m, k, v).1 == null", func (m, k) { + if (natMap.get(m, k) == null) { + natMap.replace(m, k, "v").1 == null + } else { true } + }), + ]), + + suite("delete", [ + prop_with_key("get(m, k) == null ==> delete(m, k) == m", func (m, k) { + if (natMap.get(m, k) == null) { + MapMatcher(m).matches(natMap.delete(m, k)) + } else { true } + }), + prop_with_key("delete(put(m, k, v), k) == m", func (m, k) { + if (natMap.get(m, k) == null) { + MapMatcher(m).matches(natMap.delete(natMap.put(m, k, "v"), k)) + } else { true } + }), + prop_with_key("delete(delete(m, k), k)) == delete(m, k)", func (m, k) { + let m1 = natMap.delete(natMap.delete(m, k), k); + let m2 = natMap.delete(m, k); + MapMatcher(m2).matches(m1) + }) + ]), + + suite("remove", [ + prop_with_key("remove(m, k).0 == delete(m, k)", func (m, k) { + let m1 = natMap.remove(m, k).0; + let m2 = natMap.delete(m, k); + MapMatcher(m2).matches(m1) + }), + prop_with_key("remove(put(m, k, v), k).1 == ?v", func (m, k) { + natMap.remove(natMap.put(m, k, "v"), k).1 == ?"v" + }), + prop_with_key("remove(remove(m, k).0, k).1 == null", func (m, k) { + natMap.remove(natMap.remove(m, k).0, k).1 == null + }), + prop_with_key("put(remove(m, k).0, k, remove(m, k).1) == m", func (m, k) { + if (natMap.get(m, k) != null) { + MapMatcher(m).matches(natMap.put(natMap.remove(m, k).0, k, Option.get(natMap.remove(m, k).1, ""))) + } else { true } + }) + ]), + + suite("size", [ + prop_with_key("size(put(m, k, v)) == size(m) + int(get(m, k) == null)", func (m, k) { + natMap.size(natMap.put(m, k, "v")) == natMap.size(m) + (if (natMap.get(m, k) == null) {1} else {0}) + }), + prop_with_key("size(delete(m, k)) + int(get(m, k) != null) == size(m)", func (m, k) { + natMap.size(natMap.delete(m, k)) + (if (natMap.get(m, k) != null) {1} else {0}) == natMap.size(m) + }) + ]), + + prop("search tree invariant", func (m) { + natMap.validate(m); + true + }), + + suite("keys,vals,entries,entriesRe", [ + prop("fromIter(entries(m)) == m", func (m) { + MapMatcher(m).matches(natMap.fromIter(natMap.entries(m))) + }), + prop("fromIter(entriesRev(m)) == m", func (m) { + MapMatcher(m).matches(natMap.fromIter(natMap.entriesRev(m))) + }), + prop("entries(m) = zip(key(m), vals(m))", func (m) { + let k = natMap.keys(m); + let v = natMap.vals(m); + for (e in natMap.entries(m)) { + if (e.0 != k.next() or e.1 != v.next()) + return false; + }; + return true; + }), + prop("Array.fromIter(entries(m)) == Array.fromIter(entriesRev(m)).reverse()", func (m) { + let a = Iter.toArray(natMap.entries(m)); + let b = Array.reverse(Iter.toArray(natMap.entriesRev(m))); + M.equals(T.array<(Nat, Text)>(entryTestable, a)).matches(b) + }), + ]), + + suite("mapFilter", [ + prop_with_key("get(mapFilter(m, (!=k)), k) == null", func (m, k) { + natMap.get(natMap.mapFilter(m, + func (ki, vi) { if (ki != k) {?vi} else {null}}), k) == null + }), + prop_with_key("get(mapFilter(put(m, k, v), (==k)), k) == ?v", func (m, k) { + natMap.get(natMap.mapFilter(natMap.put(m, k, "v"), + func (ki, vi) { if (ki == k) {?vi} else {null}}), k) == ?"v" + }) + ]), + + suite("map", [ + prop("map(m, id) == m", func (m) { + MapMatcher(m).matches(natMap.map(m, func (k, v) {v})) + }) + ]), + + suite("folds", [ + prop("foldLeft as entries()", func (m) { + let it = natMap.entries(m); + natMap.foldLeft(m, true, func (acc, k, v) {acc and it.next() == ?(k, v)}) + }), + prop("foldRight as entriesRev()", func(m) { + let it = natMap.entriesRev(m); + natMap.foldRight(m, true, func (k, v, acc) {acc and it.next() == ?(k, v)}) + }) + ]), + + suite("all/some", [ + prop("all through fold", func(m) { + let pred = func(k: Nat, v: Text): Bool = (k <= range.1 - 2 and range.0 + 2 <= k); + natMap.all(m, pred) == natMap.foldLeft(m, true, func (acc, k, v) {acc and pred(k, v)}) + }), + prop("some through fold", func(m) { + let pred = func(k: Nat, v: Text): Bool = (k >= range.1 - 1 or range.0 + 1 >= k); + natMap.some(m, pred) == natMap.foldLeft(m, false, func (acc, k, v) {acc or pred(k, v)}) + }), + + prop("forall k, v in map, v == show_debug(k)", func(m) { + natMap.all(m, func (k: Nat, v: Text): Bool = (v == debug_show(k))) + }), + ]), + + suite("contains", [ + prop_with_key("contains(m, k) == (get(m, k) != null)", func (m, k) { + natMap.contains(m, k) == (Option.isSome(natMap.get(m, k))) + }), + ]), + + suite("minEntry/maxEntry", [ + prop("max through fold", func (m) { + let expected = natMap.foldLeft(m, null: ?(Nat, Text), func (_, k, v) = ?(k, v) ); + M.equals(T.optional(entryTestable, expected)).matches(natMap.maxEntry(m)); + }), + + prop("min through fold", func (m) { + let expected = natMap.foldRight(m, null: ?(Nat, Text), func (k, v, _) = ?(k, v) ); + M.equals(T.optional(entryTestable, expected)).matches(natMap.minEntry(m)); + }) + ]), + ])) +}; + +run_all_props((1, 3), 0, 1, 10); +run_all_props((1, 5), 5, 100, 100); +run_all_props((1, 10), 10, 100, 100); +run_all_props((1, 100), 20, 100, 100); +run_all_props((1, 1000), 100, 100, 100); diff --git a/test/OrderedMap.test.mo b/test/OrderedMap.test.mo new file mode 100644 index 00000000..7c509af2 --- /dev/null +++ b/test/OrderedMap.test.mo @@ -0,0 +1,574 @@ +// @testmode wasi + +import Map "../src/OrderedMap"; +import Nat "../src/Nat"; +import Iter "../src/Iter"; +import Debug "../src/Debug"; +import Array "../src/Array"; + +import Suite "mo:matchers/Suite"; +import T "mo:matchers/Testable"; +import M "mo:matchers/Matchers"; + +let { run; test; suite } = Suite; + +let entryTestable = T.tuple2Testable(T.natTestable, T.textTestable); + +let natMap = Map.Make(Nat.compare); + +class MapMatcher(expected : [(Nat, Text)]) : M.Matcher> { + public func describeMismatch(actual : Map.Map, _description : M.Description) { + Debug.print(debug_show (Iter.toArray(natMap.entries(actual))) # " should be " # debug_show (expected)) + }; + + public func matches(actual : Map.Map) : Bool { + Iter.toArray(natMap.entries(actual)) == expected + } +}; + +func checkMap(m: Map.Map) { natMap.validate(m); }; + +func insert(rbTree : Map.Map, key : Nat) : Map.Map { + let updatedTree = natMap.put(rbTree, key, debug_show (key)); + checkMap(updatedTree); + updatedTree +}; + +func getAll(rbTree : Map.Map, keys : [Nat]) { + for (key in keys.vals()) { + let value = natMap.get(rbTree, key); + assert (value == ?debug_show (key)) + } +}; + +func clear(initialRbMap : Map.Map) : Map.Map { + var rbMap = initialRbMap; + for ((key, value) in natMap.entries(initialRbMap)) { + // stable iteration + assert (value == debug_show (key)); + let (newMap, result) = natMap.remove(rbMap, key); + rbMap := newMap; + assert (result == ?debug_show (key)); + checkMap(rbMap) + }; + rbMap +}; + +func expectedEntries(keys : [Nat]) : [(Nat, Text)] { + Array.tabulate<(Nat, Text)>(keys.size(), func(index) { (keys[index], debug_show (keys[index])) }) +}; + +func concatenateKeys(key : Nat, value : Text, accum : Text) : Text { + accum # debug_show(key) +}; + +func concatenateKeys2(accum : Text, key : Nat, value : Text) : Text { + accum # debug_show(key) +}; + +func concatenateValues(key : Nat, value : Text, accum : Text) : Text { + accum # value +}; + +func concatenateValues2(accum: Text, key : Nat, value : Text) : Text { + accum # value +}; + +func multiplyKeyAndConcat(key : Nat, value : Text) : Text { + debug_show(key * 2) # value +}; + +func ifKeyLessThan(threshold : Nat, f : (Nat, Text) -> Text) : (Nat, Text) -> ?Text + = func (key, value) { + if(key < threshold) + ?f(key, value) + else null + }; + +/* --------------------------------------- */ + +var buildTestMap = func() : Map.Map { + natMap.empty() +}; + +run( + suite( + "empty", + [ + test( + "size", + natMap.size(buildTestMap()), + M.equals(T.nat(0)) + ), + test( + "entries", + Iter.toArray(natMap.entries(buildTestMap())), + M.equals(T.array<(Nat, Text)>(entryTestable, [])) + ), + test( + "entriesRev", + Iter.toArray(natMap.entriesRev(buildTestMap())), + M.equals(T.array<(Nat, Text)>(entryTestable, [])) + ), + test( + "keys", + Iter.toArray(natMap.keys(buildTestMap())), + M.equals(T.array(T.natTestable, [])) + ), + test( + "vals", + Iter.toArray(natMap.vals(buildTestMap())), + M.equals(T.array(T.textTestable, [])) + ), + test( + "empty from iter", + natMap.fromIter(Iter.fromArray([])), + MapMatcher([]) + ), + test( + "get absent", + natMap.get(buildTestMap(), 0), + M.equals(T.optional(T.textTestable, null : ?Text)) + ), + test( + "contains absent", + natMap.contains(buildTestMap(), 0), + M.equals(T.bool(false)) + ), + test( + "maxEntry", + natMap.maxEntry(buildTestMap()), + M.equals(T.optional(entryTestable, null: ?(Nat, Text))) + ), + test( + "minEntry", + natMap.minEntry(buildTestMap()), + M.equals(T.optional(entryTestable, null: ?(Nat, Text))) + ), + test( + "remove absent", + natMap.remove(buildTestMap(), 0).1, + M.equals(T.optional(T.textTestable, null : ?Text)) + ), + test( + "replace absent/no value", + natMap.replace(buildTestMap(), 0, "Test").1, + M.equals(T.optional(T.textTestable, null : ?Text)) + ), + test( + "replace absent/key appeared", + natMap.replace(buildTestMap(), 0, "Test").0, + MapMatcher([(0, "Test")]) + ), + test( + "empty right fold keys", + natMap.foldRight(buildTestMap(), "", concatenateKeys), + M.equals(T.text("")) + ), + test( + "empty left fold keys", + natMap.foldLeft(buildTestMap(), "", concatenateKeys2), + M.equals(T.text("")) + ), + test( + "empty right fold values", + natMap.foldRight(buildTestMap(), "", concatenateValues), + M.equals(T.text("")) + ), + test( + "empty left fold values", + natMap.foldLeft(buildTestMap(), "", concatenateValues2), + M.equals(T.text("")) + ), + test( + "traverse empty map", + natMap.map(buildTestMap(), multiplyKeyAndConcat), + MapMatcher([]) + ), + test( + "empty map filter", + natMap.mapFilter(buildTestMap(), ifKeyLessThan(0, multiplyKeyAndConcat)), + MapMatcher([]) + ), + test( + "empty all", + natMap.all(buildTestMap(), func (k, v) = false), + M.equals(T.bool(true)) + ), + test( + "empty some", + natMap.some(buildTestMap(), func (k, v) = true), + M.equals(T.bool(false)) + ) + ] + ) +); + +/* --------------------------------------- */ + +buildTestMap := func() : Map.Map { + insert(natMap.empty(), 0); +}; + +var expected = expectedEntries([0]); + +run( + suite( + "single root", + [ + test( + "size", + natMap.size(buildTestMap()), + M.equals(T.nat(1)) + ), + test( + "entries", + Iter.toArray(natMap.entries(buildTestMap())), + M.equals(T.array<(Nat, Text)>(entryTestable, expected)) + ), + test( + "entriesRev", + Iter.toArray(natMap.entriesRev(buildTestMap())), + M.equals(T.array<(Nat, Text)>(entryTestable, expected)) + ), + test( + "keys", + Iter.toArray(natMap.keys(buildTestMap())), + M.equals(T.array(T.natTestable, [0])) + ), + test( + "vals", + Iter.toArray(natMap.vals(buildTestMap())), + M.equals(T.array(T.textTestable, ["0"])) + ), + test( + "from iter", + natMap.fromIter(Iter.fromArray(expected)), + MapMatcher(expected) + ), + test( + "get", + natMap.get(buildTestMap(), 0), + M.equals(T.optional(T.textTestable, ?"0")) + ), + test( + "contains", + natMap.contains(buildTestMap(), 0), + M.equals(T.bool(true)) + ), + test( + "maxEntry", + natMap.maxEntry(buildTestMap()), + M.equals(T.optional(entryTestable, ?(0, "0"))) + ), + test( + "minEntry", + natMap.minEntry(buildTestMap()), + M.equals(T.optional(entryTestable, ?(0, "0"))) + ), + test( + "replace function result", + natMap.replace(buildTestMap(), 0, "TEST").1, + M.equals(T.optional(T.textTestable, ?"0")) + ), + test( + "replace map result", + do { + let rbMap = buildTestMap(); + natMap.replace(rbMap, 0, "TEST").0 + }, + MapMatcher([(0, "TEST")]) + ), + test( + "remove function result", + natMap.remove(buildTestMap(), 0).1, + M.equals(T.optional(T.textTestable, ?"0")) + ), + test( + "remove map result", + do { + var rbMap = buildTestMap(); + rbMap := natMap.remove(rbMap, 0).0; + checkMap(rbMap); + rbMap + }, + MapMatcher([]) + ), + test( + "right fold keys", + natMap.foldRight(buildTestMap(), "", concatenateKeys), + M.equals(T.text("0")) + ), + test( + "left fold keys", + natMap.foldLeft(buildTestMap(), "", concatenateKeys2), + M.equals(T.text("0")) + ), + test( + "right fold values", + natMap.foldRight(buildTestMap(), "", concatenateValues), + M.equals(T.text("0")) + ), + test( + "left fold values", + natMap.foldLeft(buildTestMap(), "", concatenateValues2), + M.equals(T.text("0")) + ), + test( + "traverse map", + natMap.map(buildTestMap(), multiplyKeyAndConcat), + MapMatcher([(0, "00")]) + ), + test( + "map filter/filter all", + natMap.mapFilter(buildTestMap(), ifKeyLessThan(0, multiplyKeyAndConcat)), + MapMatcher([]) + ), + test( + "map filter/no filer", + natMap.mapFilter(buildTestMap(), ifKeyLessThan(1, multiplyKeyAndConcat)), + MapMatcher([(0, "00")]) + ), + test( + "all", + natMap.all(buildTestMap(), func (k, v) = (k == 0)), + M.equals(T.bool(true)) + ), + test( + "some", + natMap.some(buildTestMap(), func (k, v) = (k == 0)), + M.equals(T.bool(true)) + ) + ] + ) +); + +/* --------------------------------------- */ + +expected := expectedEntries([0, 1, 2]); + +func rebalanceTests(buildTestMap : () -> Map.Map) : [Suite.Suite] = + [ + test( + "size", + natMap.size(buildTestMap()), + M.equals(T.nat(3)) + ), + test( + "map match", + buildTestMap(), + MapMatcher(expected) + ), + test( + "entries", + Iter.toArray(natMap.entries(buildTestMap())), + M.equals(T.array<(Nat, Text)>(entryTestable, expected)) + ), + test( + "entriesRev", + Iter.toArray(natMap.entriesRev(buildTestMap())), + M.equals(T.array<(Nat, Text)>(entryTestable, Array.reverse(expected))) + ), + test( + "keys", + Iter.toArray(natMap.keys(buildTestMap())), + M.equals(T.array(T.natTestable, [0, 1, 2])) + ), + test( + "vals", + Iter.toArray(natMap.vals(buildTestMap())), + M.equals(T.array(T.textTestable, ["0", "1", "2"])) + ), + test( + "from iter", + natMap.fromIter(Iter.fromArray(expected)), + MapMatcher(expected) + ), + test( + "get all", + do { + let rbMap = buildTestMap(); + getAll(rbMap, [0, 1, 2]); + rbMap + }, + MapMatcher(expected) + ), + test( + "contains", + Array.tabulate(4, func (k: Nat) = (natMap.contains(buildTestMap(), k))), + M.equals(T.array(T.boolTestable, [true, true, true, false])) + ), + test( + "maxEntry", + natMap.maxEntry(buildTestMap()), + M.equals(T.optional(entryTestable, ?(2, "2"))) + ), + test( + "minEntry", + natMap.minEntry(buildTestMap()), + M.equals(T.optional(entryTestable, ?(0, "0"))) + ), + test( + "clear", + clear(buildTestMap()), + MapMatcher([]) + ), + test( + "right fold keys", + natMap.foldRight(buildTestMap(), "", concatenateKeys), + M.equals(T.text("210")) + ), + test( + "left fold keys", + natMap.foldLeft(buildTestMap(), "", concatenateKeys2), + M.equals(T.text("012")) + ), + test( + "right fold values", + natMap.foldRight(buildTestMap(), "", concatenateValues), + M.equals(T.text("210")) + ), + test( + "left fold values", + natMap.foldLeft(buildTestMap(), "", concatenateValues2), + M.equals(T.text("012")) + ), + test( + "traverse map", + natMap.map(buildTestMap(), multiplyKeyAndConcat), + MapMatcher([(0, "00"), (1, "21"), (2, "42")]) + ), + test( + "map filter/filter all", + natMap.mapFilter(buildTestMap(), ifKeyLessThan(0, multiplyKeyAndConcat)), + MapMatcher([]) + ), + test( + "map filter/filter one", + natMap.mapFilter(buildTestMap(), ifKeyLessThan(1, multiplyKeyAndConcat)), + MapMatcher([(0, "00")]) + ), + test( + "map filter/no filer", + natMap.mapFilter(buildTestMap(), ifKeyLessThan(3, multiplyKeyAndConcat)), + MapMatcher([(0, "00"), (1, "21"), (2, "42")]) + ), + test( + "all true", + natMap.all(buildTestMap(), func (k, v) = (k >= 0)), + M.equals(T.bool(true)) + ), + test( + "all false", + natMap.all(buildTestMap(), func (k, v) = (k > 0)), + M.equals(T.bool(false)) + ), + test( + "some true", + natMap.some(buildTestMap(), func (k, v) = (k >= 2)), + M.equals(T.bool(true)) + ), + test( + "some false", + natMap.some(buildTestMap(), func (k, v) = (k > 2)), + M.equals(T.bool(false)) + ) + ]; + +buildTestMap := func() : Map.Map { + var rbMap = natMap.empty() : Map.Map; + rbMap := insert(rbMap, 2); + rbMap := insert(rbMap, 1); + rbMap := insert(rbMap, 0); + rbMap +}; + +run(suite("rebalance left, left", rebalanceTests(buildTestMap))); + +/* --------------------------------------- */ + +buildTestMap := func() : Map.Map { + var rbMap = natMap.empty() : Map.Map; + rbMap := insert(rbMap, 2); + rbMap := insert(rbMap, 0); + rbMap := insert(rbMap, 1); + rbMap +}; + +run(suite("rebalance left, right", rebalanceTests(buildTestMap))); + +/* --------------------------------------- */ + +buildTestMap := func() : Map.Map { + var rbMap = natMap.empty() : Map.Map; + rbMap := insert(rbMap, 0); + rbMap := insert(rbMap, 2); + rbMap := insert(rbMap, 1); + rbMap +}; + +run(suite("rebalance right, left", rebalanceTests(buildTestMap))); + +/* --------------------------------------- */ + +buildTestMap := func() : Map.Map { + var rbMap = natMap.empty() : Map.Map; + rbMap := insert(rbMap, 0); + rbMap := insert(rbMap, 1); + rbMap := insert(rbMap, 2); + rbMap +}; + +run(suite("rebalance right, right", rebalanceTests(buildTestMap))); + +/* --------------------------------------- */ + +run( + suite( + "repeated operations", + [ + test( + "repeated insert", + do { + var rbMap = buildTestMap(); + assert (natMap.get(rbMap, 1) == ?"1"); + rbMap := natMap.put(rbMap, 1, "TEST-1"); + natMap.get(rbMap, 1) + }, + M.equals(T.optional(T.textTestable, ?"TEST-1")) + ), + test( + "repeated replace", + do { + let rbMap0 = buildTestMap(); + let (rbMap1, firstResult) = natMap.replace(rbMap0, 1, "TEST-1"); + assert (firstResult == ?"1"); + let (rbMap2, secondResult) = natMap.replace(rbMap1, 1, "1"); + assert (secondResult == ?"TEST-1"); + rbMap2 + }, + MapMatcher(expected) + ), + test( + "repeated remove", + do { + var rbMap0 = buildTestMap(); + let (rbMap1, result) = natMap.remove(rbMap0, 1); + assert (result == ?"1"); + checkMap(rbMap1); + natMap.remove(rbMap1, 1).1 + }, + M.equals(T.optional(T.textTestable, null : ?Text)) + ), + test( + "repeated delete", + do { + var rbMap = buildTestMap(); + rbMap := natMap.delete(rbMap, 1); + natMap.delete(rbMap, 1) + }, + MapMatcher(expectedEntries([0, 2])) + ) + ] + ) +); diff --git a/test/OrderedSet.prop.test.mo b/test/OrderedSet.prop.test.mo new file mode 100644 index 00000000..4e92db06 --- /dev/null +++ b/test/OrderedSet.prop.test.mo @@ -0,0 +1,334 @@ +// @testmode wasi + +import Set "../src/OrderedSet"; +import Nat "../src/Nat"; +import Iter "../src/Iter"; +import Debug "../src/Debug"; +import Array "../src/Array"; + +import Suite "mo:matchers/Suite"; +import T "mo:matchers/Testable"; +import M "mo:matchers/Matchers"; + +let { run; test; suite } = Suite; + +let natSet = Set.Make(Nat.compare); + +class SetMatcher(expected : Set.Set) : M.Matcher> { + public func describeMismatch(actual : Set.Set, _description : M.Description) { + Debug.print(debug_show (Iter.toArray(natSet.vals(actual))) # " should be " # debug_show (Iter.toArray(natSet.vals(expected)))) + }; + + public func matches(actual : Set.Set) : Bool { + natSet.equals(actual, expected) + } +}; + +object Random { + var number = 4711; + public func next() : Nat { + number := (15485863 * number + 5) % 15485867; + number + }; + + public func nextNat(range: (Nat, Nat)): Nat { + let n = next(); + let v = n % (range.1 - range.0 + 1) + range.0; + v + }; + + public func nextEntries(range: (Nat, Nat), size: Nat): [Nat] { + Array.tabulate(size, func(_ix) { + let key = nextNat(range); key }) + } +}; + +func setGenN(samples_number: Nat, size: Nat, range: (Nat, Nat), chunkSize: Nat): Iter.Iter<[Set.Set]> { + object { + var n = 0; + public func next(): ?([Set.Set]) { + n += 1; + if (n > samples_number) { + null + } else { + ?Array.tabulate>(chunkSize, func _i = natSet.fromIter(Random.nextEntries(range, size).vals())) + } + } + } +}; + +func run_all_props(range: (Nat, Nat), size: Nat, set_samples: Nat, query_samples: Nat) { + func prop(name: Text, f: Set.Set -> Bool): Suite.Suite { + var error_msg: Text = ""; + test(name, do { + var error = true; + label stop for(sets in setGenN(set_samples, size, range, 1)) { + if (not f(sets[0])) { + error_msg := "Property \"" # name # "\" failed\n"; + error_msg #= "\n s: " # debug_show(Iter.toArray(natSet.vals(sets[0]))); + break stop; + } + }; + error_msg + }, M.describedAs(error_msg, M.equals(T.text("")))) + }; + + func prop2(name: Text, f: (Set.Set, Set.Set) -> Bool): Suite.Suite { + var error_msg: Text = ""; + test(name, do { + var error = true; + label stop for(sets in setGenN(set_samples, size, range, 2)) { + if (not f(sets[0], sets[1])) { + error_msg := "Property \"" # name # "\" failed\n"; + error_msg #= "\n s1: " # debug_show(Iter.toArray(natSet.vals(sets[0]))); + error_msg #= "\n s2: " # debug_show(Iter.toArray(natSet.vals(sets[1]))); + break stop; + } + }; + error_msg + }, M.describedAs(error_msg, M.equals(T.text("")))) + }; + + func prop3(name: Text, f: (Set.Set, Set.Set, Set.Set) -> Bool): Suite.Suite { + var error_msg: Text = ""; + test(name, do { + var error = true; + label stop for(sets in setGenN(set_samples, size, range, 3)) { + if (not f(sets[0], sets[1], sets[2])) { + error_msg := "Property \"" # name # "\" failed\n"; + error_msg #= "\n s1: " # debug_show(Iter.toArray(natSet.vals(sets[0]))); + error_msg #= "\n s2: " # debug_show(Iter.toArray(natSet.vals(sets[1]))); + error_msg #= "\n s3: " # debug_show(Iter.toArray(natSet.vals(sets[2]))); + break stop; + } + }; + error_msg + }, M.describedAs(error_msg, M.equals(T.text("")))) + }; + + func prop_with_elem(name: Text, f: (Set.Set, Nat) -> Bool): Suite.Suite { + var error_msg: Text = ""; + test(name, do { + label stop for(sets in setGenN(set_samples, size, range, 1)) { + for (_query_ix in Iter.range(0, query_samples-1)) { + let key = Random.nextNat(range); + if (not f(sets[0], key)) { + error_msg #= "Property \"" # name # "\" failed"; + error_msg #= "\n s: " # debug_show(Iter.toArray(natSet.vals(sets[0]))); + error_msg #= "\n e: " # debug_show(key); + break stop; + } + } + }; + error_msg + }, M.describedAs(error_msg, M.equals(T.text("")))) + }; + + run( + suite("Property tests", + [ + suite("empty", [ + test("not contains(empty(), e)", label res : Bool { + for (_query_ix in Iter.range(0, query_samples-1)) { + let elem = Random.nextNat(range); + if(natSet.contains(natSet.empty(), elem)) + break res(false); + }; + true; + }, M.equals(T.bool(true))) + ]), + + suite("contains & put", [ + prop_with_elem("contains(put(s, e), e)", func (s, e) { + natSet.contains(natSet.put(s, e), e) + }), + prop_with_elem("put(put(s, e), e) == put(s, e)", func (s, e) { + let s1 = natSet.put(s, e); + let s2 = natSet.put(natSet.put(s, e), e); + SetMatcher(s1).matches(s2) + }), + ]), + + suite("folds", [ + prop("foldLeft as vals()", func (m) { + let it = natSet.vals(m); + natSet.foldLeft(m, true, func (acc, v) {acc and it.next() == ?v}) + }), + prop("foldRight as valsRev()", func(m) { + let it = natSet.valsRev(m); + natSet.foldRight(m, true, func (v, acc) {acc and it.next() == ?v}) + }) + ]), + + suite("min/max", [ + prop("max through fold", func (s) { + let expected = natSet.foldLeft(s, null: ?Nat, func (_, v) = ?v ); + M.equals(T.optional(T.natTestable, expected)).matches(natSet.max(s)); + }), + prop("min through fold", func (s) { + let expected = natSet.foldRight(s, null: ?Nat, func (v, _) = ?v ); + M.equals(T.optional(T.natTestable, expected)).matches(natSet.min(s)); + }), + ]), + + suite("all/some", [ + prop("all through fold", func(s) { + let pred = func(k: Nat): Bool = (k <= range.1 - 2 and range.0 + 2 <= k); + natSet.all(s, pred) == natSet.foldLeft(s, true, func (acc, v) {acc and pred(v)}) + }), + prop("some through fold", func(s) { + let pred = func(k: Nat): Bool = (k >= range.1 - 1 or range.0 + 1 >= k); + natSet.some(s, pred) == natSet.foldLeft(s, false, func (acc, v) {acc or pred(v)}) + }), + ]), + + suite("delete", [ + prop_with_elem("not contains(s, e) ==> delete(s, e) == s", func (s, e) { + if (not natSet.contains(s, e)) { + SetMatcher(s).matches(natSet.delete(s, e)) + } else { true } + }), + prop_with_elem("delete(put(s, e), e) == s", func (s, e) { + if (not natSet.contains(s, e)) { + SetMatcher(s).matches(natSet.delete(natSet.put(s, e), e)) + } else { true } + }), + prop_with_elem("delete(delete(s, e), e)) == delete(s, e)", func (s, e) { + let s1 = natSet.delete(natSet.delete(s, e), e); + let s2 = natSet.delete(s, e); + SetMatcher(s2).matches(s1) + }) + ]), + + suite("size", [ + prop_with_elem("size(put(s, e)) == size(s) + int(not contains(s, e))", func (s, e) { + natSet.size(natSet.put(s, e)) == natSet.size(s) + (if (not natSet.contains(s, e)) {1} else {0}) + }), + prop_with_elem("size(delete(s, e)) + int(contains(s, e)) == size(s)", func (s, e) { + natSet.size(natSet.delete(s, e)) + (if (natSet.contains(s, e)) {1} else {0}) == natSet.size(s) + }) + ]), + + suite("vals/valsRev", [ + prop("fromIter(vals(s)) == s", func (s) { + SetMatcher(s).matches(natSet.fromIter(natSet.vals(s))) + }), + prop("fromIter(valsRev(s)) == s", func (s) { + SetMatcher(s).matches(natSet.fromIter(natSet.valsRev(s))) + }), + prop("toArray(vals(s)).reverse() == toArray(valsRev(s))", func (s) { + let a = Array.reverse(Iter.toArray(natSet.vals(s))); + let b = Iter.toArray(natSet.valsRev(s)); + M.equals(T.array(T.natTestable, a)).matches(b) + }), + ]), + + suite(("Internal"), [ + prop("search tree invariant", func (s) { + natSet.validate(s); + true + }) + ]), + + suite("mapFilter", [ + prop_with_elem("not contains(mapFilter(s, (!=e)), e)", func (s, e) { + not natSet.contains(natSet.mapFilter(s, + func (ei) { if (ei != e) {?ei} else {null}}), e) + }), + prop_with_elem("contains(mapFilter(put(s, e), (==e)), e)", func (s, e) { + natSet.contains(natSet.mapFilter(natSet.put(s, e), + func (ei) { if (ei == e) {?ei} else {null}}), e) + }) + ]), + + suite("map", [ + prop("map(s, id) == s", func (s) { + SetMatcher(s).matches(natSet.map(s, func (e) {e})) + }) + ]), + + suite("set operations", [ + prop("isSubset(s, s)", func (s) { + natSet.isSubset(s, s) + }), + prop("isSubset(empty(), s)", func (s) { + natSet.isSubset(natSet.empty(), s) + }), + prop_with_elem("isSubset(delete(s, e), s)", func (s, e) { + natSet.isSubset(natSet.delete(s, e), s) + }), + prop_with_elem("contains(s, e) ==> not isSubset(s, delete(s, e))", func (s, e) { + if (natSet.contains(s, e)) { + not natSet.isSubset(s, natSet.delete(s, e)) + } else { true } + }), + prop_with_elem("isSubset(s, put(s, e))", func (s, e) { + natSet.isSubset(s, natSet.put(s, e)) + }), + prop_with_elem("not contains(s, e) ==> not isSubset(put(s, e), s)", func (s, e) { + if (not natSet.contains(s, e)) { + not natSet.isSubset(natSet.put(s, e), s) + } else { true } + }), + prop("intersect(empty(), s) == empty()", func (s) { + SetMatcher(natSet.empty()).matches(natSet.intersect(natSet.empty(), s)) + }), + prop("intersect(s, empty()) == empty()", func (s) { + SetMatcher(natSet.empty()).matches(natSet.intersect(s, natSet.empty())) + }), + prop("union(s, empty()) == s", func (s) { + SetMatcher(s).matches(natSet.union(s, natSet.empty())) + }), + prop("union(empty(), s) == s", func (s) { + SetMatcher(s).matches(natSet.union(natSet.empty(), s)) + }), + prop("diff(empty(), s) == empty()", func (s) { + SetMatcher(natSet.empty()).matches(natSet.diff(natSet.empty(), s)) + }), + prop("diff(s, empty()) == s", func (s) { + SetMatcher(s).matches(natSet.diff(s, natSet.empty())) + }), + prop("intersect(s, s) == s", func (s) { + SetMatcher(s).matches(natSet.intersect(s, s)) + }), + prop("union(s, s) == s", func (s) { + SetMatcher(s).matches(natSet.union(s, s)) + }), + prop("diff(s, s) == empty()", func (s) { + SetMatcher(natSet.empty()).matches(natSet.diff(s, s)) + }), + prop2("intersect(s1, s2) == intersect(s2, s1)", func (s1, s2) { + SetMatcher(natSet.intersect(s1, s2)).matches(natSet.intersect(s2, s1)) + }), + prop2("union(s1, s2) == union(s2, s1)", func (s1, s2) { + SetMatcher(natSet.union(s1, s2)).matches(natSet.union(s2, s1)) + }), + prop2("isSubset(diff(s1, s2), s1)", func (s1, s2) { + natSet.isSubset(natSet.diff(s1, s2), s1) + }), + prop2("intersect(diff(s1, s2), s2) == empty()", func (s1, s2) { + SetMatcher(natSet.intersect(natSet.diff(s1, s2), s2)).matches(natSet.empty()) + }), + prop3("union(union(s1, s2), s3) == union(s1, union(s2, s3))", func (s1, s2, s3) { + SetMatcher(natSet.union(natSet.union(s1, s2), s3)).matches(natSet.union(s1, natSet.union(s2, s3))) + }), + prop3("intersect(intersect(s1, s2), s3) == intersect(s1, intersect(s2, s3))", func (s1, s2, s3) { + SetMatcher(natSet.intersect(natSet.intersect(s1, s2), s3)).matches(natSet.intersect(s1, natSet.intersect(s2, s3))) + }), + prop3("union(s1, intersect(s2, s3)) == intersect(union(s1, s2), union(s1, s3))", func (s1, s2, s3) { + SetMatcher(natSet.union(s1, natSet.intersect(s2, s3))).matches( + natSet.intersect(natSet.union(s1, s2), natSet.union(s1, s3))) + }), + prop3("intersect(s1, union(s2, s3)) == union(intersect(s1, s2), intersect(s1, s3))", func (s1, s2, s3) { + SetMatcher(natSet.intersect(s1, natSet.union(s2, s3))).matches( + natSet.union(natSet.intersect(s1, s2), natSet.intersect(s1, s3))) + }), + ]), + ])) +}; + +run_all_props((1, 3), 0, 1, 10); +run_all_props((1, 5), 5, 100, 100); +run_all_props((1, 10), 10, 100, 100); +run_all_props((1, 100), 20, 100, 100); +run_all_props((1, 1000), 100, 100, 100); diff --git a/test/OrderedSet.test.mo b/test/OrderedSet.test.mo new file mode 100644 index 00000000..eb2f1e4d --- /dev/null +++ b/test/OrderedSet.test.mo @@ -0,0 +1,600 @@ +// @testmode wasi + +import Set "../src/OrderedSet"; +import Array "../src/Array"; +import Nat "../src/Nat"; +import Iter "../src/Iter"; +import Debug "../src/Debug"; + +import Suite "mo:matchers/Suite"; +import T "mo:matchers/Testable"; +import M "mo:matchers/Matchers"; + +let { run; test; suite } = Suite; + +let entryTestable = T.natTestable; + +class SetMatcher(expected : [Nat]) : M.Matcher> { + public func describeMismatch(actual : Set.Set, _description : M.Description) { + Debug.print(debug_show (Iter.toArray(natSet.vals(actual))) # " should be " # debug_show (expected)) + }; + + public func matches(actual : Set.Set) : Bool { + Iter.toArray(natSet.vals(actual)) == expected + } +}; + +let natSet = Set.Make(Nat.compare); + +func insert(s : Set.Set, key : Nat) : Set.Set { + let updatedTree = natSet.put(s, key); + natSet.validate(updatedTree); + updatedTree +}; + +func concatenateKeys(key : Nat, accum : Text) : Text { + accum # debug_show(key) +}; + +func concatenateKeys2(accum : Text, key : Nat) : Text { + accum # debug_show(key) +}; + +func containsAll (rbSet : Set.Set, elems : [Nat]) { + for (elem in elems.vals()) { + assert (natSet.contains(rbSet, elem)) + } +}; + +func clear(initialRbSet : Set.Set) : Set.Set { + var rbSet = initialRbSet; + for (elem in natSet.vals(initialRbSet)) { + let newSet = natSet.delete(rbSet, elem); + rbSet := newSet; + natSet.validate(rbSet) + }; + rbSet +}; + +func add1(x : Nat) : Nat { x + 1 }; + +func ifElemLessThan(threshold : Nat, f : Nat -> Nat) : Nat -> ?Nat + = func (x) { + if(x < threshold) + ?f(x) + else null + }; + + +/* --------------------------------------- */ + +var buildTestSet = func() : Set.Set { + natSet.empty() +}; + +run( + suite( + "empty", + [ + test( + "size", + natSet.size(buildTestSet()), + M.equals(T.nat(0)) + ), + test( + "vals", + Iter.toArray(natSet.vals(buildTestSet())), + M.equals(T.array(entryTestable, [])) + ), + test( + "valsRev", + Iter.toArray(natSet.vals(buildTestSet())), + M.equals (T.array(entryTestable, [])) + ), + test( + "empty from iter", + natSet.fromIter(Iter.fromArray([])), + SetMatcher([]) + ), + test( + "contains absent", + natSet.contains(buildTestSet(), 0), + M.equals(T.bool(false)) + ), + test( + "empty right fold", + natSet.foldRight(buildTestSet(), "", concatenateKeys), + M.equals(T.text("")) + ), + test( + "empty left fold", + natSet.foldLeft(buildTestSet(), "", concatenateKeys2), + M.equals(T.text("")) + ), + test( + "traverse empty set", + natSet.map(buildTestSet(), add1), + SetMatcher([]) + ), + test( + "empty map filter", + natSet.mapFilter(buildTestSet(), ifElemLessThan(0, add1)), + SetMatcher([]) + ), + test( + "is empty", + natSet.isEmpty(buildTestSet()), + M.equals(T.bool(true)) + ), + test( + "max", + natSet.max(buildTestSet()), + M.equals(T.optional(entryTestable, null: ?Nat)) + ), + test( + "min", + natSet.min(buildTestSet()), + M.equals(T.optional(entryTestable, null: ?Nat)) + ) + ] + ) +); + +/* --------------------------------------- */ + +buildTestSet := func() : Set.Set { + insert(natSet.empty(), 0); +}; + +var expected = [0]; + +run( + suite( + "single root", + [ + test( + "size", + natSet.size(buildTestSet()), + M.equals(T.nat(1)) + ), + test( + "vals", + Iter.toArray(natSet.vals(buildTestSet())), + M.equals(T.array(entryTestable, expected)) + ), + test( + "valsRev", + Iter.toArray(natSet.valsRev(buildTestSet())), + M.equals(T.array(entryTestable, expected)) + ), + test( + "from iter", + natSet.fromIter(Iter.fromArray(expected)), + SetMatcher(expected) + ), + test( + "contains", + natSet.contains(buildTestSet(), 0), + M.equals(T.bool(true)) + ), + test( + "delete", + natSet.delete(buildTestSet(), 0), + SetMatcher([]) + ), + test( + "right fold", + natSet.foldRight(buildTestSet(), "", concatenateKeys), + M.equals(T.text("0")) + ), + test( + "left fold", + natSet.foldLeft(buildTestSet(), "", concatenateKeys2), + M.equals(T.text("0")) + ), + test( + "traverse set", + natSet.map(buildTestSet(), add1), + SetMatcher([1]) + ), + test( + "map filter/filter all", + natSet.mapFilter(buildTestSet(), ifElemLessThan(0, add1)), + SetMatcher([]) + ), + test( + "map filter/no filer", + natSet.mapFilter(buildTestSet(), ifElemLessThan(1, add1)), + SetMatcher([1]) + ), + test( + "is empty", + natSet.isEmpty(buildTestSet()), + M.equals(T.bool(false)) + ), + test( + "max", + natSet.max(buildTestSet()), + M.equals(T.optional(entryTestable, ?0)) + ), + test( + "min", + natSet.min(buildTestSet()), + M.equals(T.optional(entryTestable, ?0)) + ), + test( + "all", + natSet.all(buildTestSet(), func (k) = (k == 0)), + M.equals(T.bool(true)) + ), + test( + "some", + natSet.some(buildTestSet(), func (k) = (k == 0)), + M.equals(T.bool(true)) + ), + ] + ) +); + +/* --------------------------------------- */ + +expected := [0, 1, 2]; + +func rebalanceTests(buildTestSet : () -> Set.Set) : [Suite.Suite] = + [ + test( + "size", + natSet.size(buildTestSet()), + M.equals(T.nat(3)) + ), + test( + "Set match", + buildTestSet(), + SetMatcher(expected) + ), + test( + "vals", + Iter.toArray(natSet.vals(buildTestSet())), + M.equals(T.array(entryTestable, expected)) + ), + test( + "valsRev", + Array.reverse(Iter.toArray(natSet.valsRev(buildTestSet()))), + M.equals(T.array(entryTestable, expected)) + ), + test( + "from iter", + natSet.fromIter(Iter.fromArray(expected)), + SetMatcher(expected) + ), + test( + "contains all", + do { + let rbSet = buildTestSet(); + containsAll(rbSet, [0, 1, 2]); + rbSet + }, + SetMatcher(expected) + ), + test( + "clear", + clear(buildTestSet()), + SetMatcher([]) + ), + test( + "right fold", + natSet.foldRight(buildTestSet(), "", concatenateKeys), + M.equals(T.text("210")) + ), + test( + "left fold", + natSet.foldLeft(buildTestSet(), "", concatenateKeys2), + M.equals(T.text("012")) + ), + test( + "traverse set", + natSet.map(buildTestSet(), add1), + SetMatcher([1, 2, 3]) + ), + test( + "traverse set/reshape", + natSet.map(buildTestSet(), func (x : Nat) : Nat {5}), + SetMatcher([5]) + ), + test( + "map filter/filter all", + natSet.mapFilter(buildTestSet(), ifElemLessThan(0, add1)), + SetMatcher([]) + ), + test( + "map filter/filter one", + natSet.mapFilter(buildTestSet(), ifElemLessThan(1, add1)), + SetMatcher([1]) + ), + test( + "map filter/no filer", + natSet.mapFilter(buildTestSet(), ifElemLessThan(3, add1)), + SetMatcher([1, 2, 3]) + ), + test( + "is empty", + natSet.isEmpty(buildTestSet()), + M.equals(T.bool(false)) + ), + test( + "max", + natSet.max(buildTestSet()), + M.equals(T.optional(entryTestable, ?2)) + ), + test( + "min", + natSet.min(buildTestSet()), + M.equals(T.optional(entryTestable, ?0)) + ), + test( + "all true", + natSet.all(buildTestSet(), func (k) = (k >= 0)), + M.equals(T.bool(true)) + ), + test( + "all false", + natSet.all(buildTestSet(), func (k) = (k > 0)), + M.equals(T.bool(false)) + ), + test( + "some true", + natSet.some(buildTestSet(), func (k) = (k >= 2)), + M.equals(T.bool(true)) + ), + test( + "some false", + natSet.some(buildTestSet(), func (k) = (k > 2)), + M.equals(T.bool(false)) + ), + ]; + +buildTestSet := func() : Set.Set { + var rbSet = natSet.empty(); + rbSet := insert(rbSet, 2); + rbSet := insert(rbSet, 1); + rbSet := insert(rbSet, 0); + rbSet +}; + +run(suite("rebalance left, left", rebalanceTests(buildTestSet))); + +/* --------------------------------------- */ + +buildTestSet := func() : Set.Set { + var rbSet = natSet.empty(); + rbSet := insert(rbSet, 2); + rbSet := insert(rbSet, 0); + rbSet := insert(rbSet, 1); + rbSet +}; + +run(suite("rebalance left, right", rebalanceTests(buildTestSet))); + +/* --------------------------------------- */ + +buildTestSet := func() : Set.Set { + var rbSet = natSet.empty(); + rbSet := insert(rbSet, 0); + rbSet := insert(rbSet, 2); + rbSet := insert(rbSet, 1); + rbSet +}; + +run(suite("rebalance right, left", rebalanceTests(buildTestSet))); + +/* --------------------------------------- */ + +buildTestSet := func() : Set.Set { + var rbSet = natSet.empty(); + rbSet := insert(rbSet, 0); + rbSet := insert(rbSet, 1); + rbSet := insert(rbSet, 2); + rbSet +}; + +run(suite("rebalance right, right", rebalanceTests(buildTestSet))); + +/* --------------------------------------- */ + +run( + suite( + "repeated operations", + [ + test( + "repeated insert", + do { + var rbSet = buildTestSet(); + assert (natSet.contains(rbSet, 1)); + rbSet := natSet.put(rbSet, 1); + natSet.size(rbSet) + }, + M.equals(T.nat(3)) + ), + test( + "repeated delete", + do { + var rbSet = buildTestSet(); + rbSet := natSet.delete(rbSet, 1); + natSet.delete(rbSet, 1) + }, + SetMatcher([0, 2]) + ) + ] + ) +); + +/* --------------------------------------- */ + +let buildTestSet012 = func() : Set.Set { + var rbSet = natSet.empty(); + rbSet := insert(rbSet, 0); + rbSet := insert(rbSet, 1); + rbSet := insert(rbSet, 2); + rbSet +}; + +let buildTestSet01 = func() : Set.Set { + var rbSet = natSet.empty(); + rbSet := insert(rbSet, 0); + rbSet := insert(rbSet, 1); + rbSet +}; + +let buildTestSet234 = func() : Set.Set { + var rbSet = natSet.empty(); + rbSet := insert(rbSet, 2); + rbSet := insert(rbSet, 3); + rbSet := insert(rbSet, 4); + rbSet +}; + +let buildTestSet345 = func() : Set.Set { + var rbSet = natSet.empty(); + rbSet := insert(rbSet, 5); + rbSet := insert(rbSet, 3); + rbSet := insert(rbSet, 4); + rbSet +}; + +run( + suite( + "set operations", + [ + test( + "subset/subset of itself", + natSet.isSubset(buildTestSet012(), buildTestSet012()), + M.equals(T.bool(true)) + ), + test( + "subset/empty set is subset of itself", + natSet.isSubset(natSet.empty(), natSet.empty()), + M.equals(T.bool(true)) + ), + test( + "subset/empty set is subset of another set", + natSet.isSubset(natSet.empty(), buildTestSet012()), + M.equals(T.bool(true)) + ), + test( + "subset/subset", + natSet.isSubset(buildTestSet01(), buildTestSet012()), + M.equals(T.bool(true)) + ), + test( + "subset/not subset", + natSet.isSubset(buildTestSet012(), buildTestSet01()), + M.equals(T.bool(false)) + ), + test( + "equals/empty set", + natSet.equals(natSet.empty(), natSet.empty()), + M.equals(T.bool(true)) + ), + test( + "equals/equals", + natSet.equals(buildTestSet012(), buildTestSet012()), + M.equals(T.bool(true)) + ), + test( + "equals/not equals", + natSet.equals(buildTestSet012(), buildTestSet01()), + M.equals(T.bool(false)) + ), + test( + "union/empty set", + natSet.union(natSet.empty(), natSet.empty()), + SetMatcher([]) + ), + test( + "union/union with empty set", + natSet.union(buildTestSet012(), natSet.empty()), + SetMatcher([0, 1, 2]) + ), + test( + "union/union with itself", + natSet.union(buildTestSet012(), buildTestSet012()), + SetMatcher([0, 1, 2]) + ), + test( + "union/union with subset", + natSet.union(buildTestSet012(), buildTestSet01()), + SetMatcher([0, 1, 2]) + ), + test( + "union/union expand", + natSet.union(buildTestSet012(), buildTestSet234()), + SetMatcher([0, 1, 2, 3, 4]) + ), + test( + "intersect/empty set", + natSet.intersect(natSet.empty(), natSet.empty()), + SetMatcher([]) + ), + test( + "intersect/intersect with empty set", + natSet.intersect(buildTestSet012(), natSet.empty()), + SetMatcher([]) + ), + test( + "intersect/intersect with itself", + natSet.intersect(buildTestSet012(), buildTestSet012()), + SetMatcher([0, 1, 2]) + ), + test( + "intersect/intersect with subset", + natSet.intersect(buildTestSet012(), buildTestSet01()), + SetMatcher([0, 1]) + ), + test( + "intersect/intersect", + natSet.intersect(buildTestSet012(), buildTestSet234()), + SetMatcher([2]) + ), + test( + "intersect/no intersection", + natSet.intersect(buildTestSet012(), buildTestSet345()), + SetMatcher([]) + ), + test( + "diff/empty set", + natSet.diff(natSet.empty(), natSet.empty()), + SetMatcher([]) + ), + test( + "diff/diff with empty set", + natSet.diff(buildTestSet012(), natSet.empty()), + SetMatcher([0, 1, 2]) + ), + test( + "diff/diff with empty set 2", + natSet.diff(natSet.empty(), buildTestSet012()), + SetMatcher([]) + ), + test( + "diff/diff with subset", + natSet.diff(buildTestSet012(), buildTestSet01()), + SetMatcher([2]) + ), + test( + "diff/diff with subset 2", + natSet.diff(buildTestSet01(), buildTestSet012()), + SetMatcher([]) + ), + test( + "diff/diff", + natSet.diff(buildTestSet012(), buildTestSet234()), + SetMatcher([0, 1]) + ), + test( + "diff/diff no intersection", + natSet.diff(buildTestSet012(), buildTestSet345()), + SetMatcher([0, 1, 2]) + ), + ] + ) +);