Skip to content

Commit

Permalink
make union_find compile with prelude
Browse files Browse the repository at this point in the history
  • Loading branch information
zapashcanon committed Feb 10, 2025
1 parent 6819918 commit a45fb78
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 50 deletions.
93 changes: 43 additions & 50 deletions src/data_structures/union_find.ml
Original file line number Diff line number Diff line change
Expand Up @@ -47,51 +47,47 @@ module Make (X : VariableType) : S with type key = X.t = struct
type key = X.t

type 'a node =
{ aliases : SX.t;
cardinal : int;
datum : 'a option
{ aliases : SX.t
; cardinal : int
; datum : 'a option
}

type 'a t =
{ canonical_elements : X.t MX.t;
node_of_canonicals : 'a node MX.t
{ canonical_elements : X.t MX.t
; node_of_canonicals : 'a node MX.t
}

let print_set ppf set =
if SX.is_empty set
then Format.fprintf ppf "{}"
if SX.is_empty set then Fmt.pf ppf "{}"
else (
Format.fprintf ppf "@[<hov 1>{";
Fmt.pf ppf "@[<hov 1>{";
let first = ref true in
SX.iter
(fun x ->
if !first then first := false else Format.fprintf ppf ",@ ";
X.print ppf x)
if !first then first := false else Fmt.pf ppf ",@ ";
X.print ppf x )
set;
Format.fprintf ppf "}@]")
Fmt.pf ppf "}@]" )

let print_map pp ppf map =
if MX.is_empty map
then Format.fprintf ppf "{}"
if MX.is_empty map then Fmt.pf ppf "{}"
else (
Format.fprintf ppf "@[<hov 1>{";
Fmt.pf ppf "@[<hov 1>{";
let first = ref true in
MX.iter
(fun key value ->
if !first then first := false else Format.fprintf ppf ",@ ";
Format.fprintf ppf "@[<hov 1>(%a@ %a)@]" X.print key pp value)
if !first then first := false else Fmt.pf ppf ",@ ";
Fmt.pf ppf "@[<hov 1>(%a@ %a)@]" X.print key pp value )
map;
Format.fprintf ppf "}@]")
Fmt.pf ppf "}@]" )

let print_aliases ppf { aliases; _ } = print_set ppf aliases

let print_datum pp ppf { datum; _ } =
Format.pp_print_option
~none:(fun ppf () -> Format.fprintf ppf "<default>")
pp ppf datum
Fmt.option ~none:(fun ppf () -> Fmt.pf ppf "<default>") pp ppf datum

let[@ocamlformat "disable"] print pp ppf { node_of_canonicals; _ } =
Format.fprintf ppf
Fmt.pf ppf
"@[<hov 1>(\
@[<hov 1>(aliases_of_canonicals@ %a)@]@ \
@[<hov 1>(payload_of_canonicals@ %a)@]\
Expand All @@ -109,26 +105,26 @@ module Make (X : VariableType) : S with type key = X.t = struct
let add ~merge variable datum t =
let variable = find_canonical variable t in
let node_of_canonicals =
MX.update variable (function
| None ->
Some { aliases = SX.empty ; cardinal = 0 ; datum = Some datum }
| Some node ->
let datum =
match node.datum with
| None -> Some datum
| Some existing_datum ->
Some (merge datum existing_datum)
in
Some { node with datum })
t.node_of_canonicals
MX.update variable
(function
| None ->
Some { aliases = SX.empty; cardinal = 0; datum = Some datum }
| Some node ->
let datum =
match node.datum with
| None -> Some datum
| Some existing_datum -> Some (merge datum existing_datum)
in
Some { node with datum } )
t.node_of_canonicals
in
{ t with node_of_canonicals }

let find_node_opt canonical t = MX.find_opt canonical t.node_of_canonicals

let find_node canonical t =
match find_node_opt canonical t with
| None -> { aliases = SX.empty ; cardinal = 0 ; datum = None }
| None -> { aliases = SX.empty; cardinal = 0; datum = None }
| Some node -> node

let find_opt variable t =
Expand All @@ -138,39 +134,36 @@ module Make (X : VariableType) : S with type key = X.t = struct

let set_canonical_element aliases canonical canonical_elements =
SX.fold
(fun alias canonical_elements ->
MX.add alias canonical canonical_elements)
(fun alias canonical_elements -> MX.add alias canonical canonical_elements)
aliases canonical_elements

let union ~merge lhs rhs t =
let lhs = find_canonical lhs t in
let rhs = find_canonical rhs t in
if X.equal lhs rhs
then t
if X.equal lhs rhs then t
else
let lhs_node = find_node lhs t in
let rhs_node = find_node rhs t in
let demoted, canonical, canonical_elements =
if lhs_node.cardinal < rhs_node.cardinal
then
( lhs,
rhs,
set_canonical_element lhs_node.aliases rhs t.canonical_elements )
if lhs_node.cardinal < rhs_node.cardinal then
( lhs
, rhs
, set_canonical_element lhs_node.aliases rhs t.canonical_elements )
else
( rhs,
lhs,
set_canonical_element rhs_node.aliases lhs t.canonical_elements )
( rhs
, lhs
, set_canonical_element rhs_node.aliases lhs t.canonical_elements )
in
let datum =
match lhs_node.datum, rhs_node.datum with
match (lhs_node.datum, rhs_node.datum) with
| None, None -> None
| None, Some datum | Some datum, None -> Some datum
| Some lhs_datum, Some rhs_datum -> Some (merge lhs_datum rhs_datum)
in
let node =
{ aliases = SX.union lhs_node.aliases rhs_node.aliases;
cardinal = lhs_node.cardinal + rhs_node.cardinal + 1;
datum
{ aliases = SX.union lhs_node.aliases rhs_node.aliases
; cardinal = lhs_node.cardinal + rhs_node.cardinal + 1
; datum
}
in
let node_of_canonicals = MX.add canonical node t.node_of_canonicals in
Expand Down
1 change: 1 addition & 0 deletions src/dune
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
tracing
trap
types
union_find
value_intf
v
wasm_ffi_intf
Expand Down

0 comments on commit a45fb78

Please sign in to comment.