From e138adb6b091338d4be6cbcc041b356e83237681 Mon Sep 17 00:00:00 2001 From: Dhruv Makwana Date: Sun, 29 Dec 2024 16:28:17 +0000 Subject: [PATCH] CN: Use simpler, custom monad for WellTyped This commit removes functor from around the implementation of WellTyped, and uses a simpler Error and Reader monad based only on Context.t, hence no solver. Like Global, it also provides an transformer functor to lift its exposed monadic API to a give target. --- backend/cn/lib/typing.ml | 78 +++---------- backend/cn/lib/wellTyped.ml | 215 ++++++++++++++++++++++++++++++++++- backend/cn/lib/wellTyped.mli | 22 +++- 3 files changed, 247 insertions(+), 68 deletions(-) diff --git a/backend/cn/lib/typing.ml b/backend/cn/lib/typing.ml index eaecc0909..a0334fa9c 100644 --- a/backend/cn/lib/typing.ml +++ b/backend/cn/lib/typing.ml @@ -195,7 +195,6 @@ let modify_where (f : Where.t -> Where.t) : unit t = { s with log; typing_context }) -(** TODO move the option part of this to Memory *) let get_member_type loc member layout : Sctypes.t m = let member_types = Memory.member_types layout in match List.assoc_opt Id.equal member member_types with @@ -204,22 +203,28 @@ let get_member_type loc member layout : Sctypes.t m = fail (fun _ -> { loc; msg = Unexpected_member (List.map fst member_types, member) }) -module Global = struct - include Global.Lift (struct - type nonrec 'a t = 'a t +module ErrorReader = struct + type nonrec 'a t = 'a t - let return = return + let return = return + + let bind = bind - let bind = bind + type state = s - type state = s + type global = Global.t - type global = Global.t + let get = get - let get = get + let to_global (s : s) = s.typing_context.global - let to_global (s : s) = s.typing_context.global - end) + let to_context (s : s) = s.typing_context + + let lift = lift +end + +module Global = struct + include Global.Lift (ErrorReader) let empty = Global.empty @@ -829,57 +834,8 @@ let test_value_eqs loc guard x ys = loop group ms ys -module NoSolver = struct - type nonrec 'a t = 'a t - - type nonrec failure = failure - - let liftFail typeErr _ = typeErr - - let return = return - - let bind = bind - - let pure = pure - - let fail = fail - - let bound_a = bound_a - - let bound_l = bound_l - - let get_a = get_a - - let get_l = get_l - - let add_a = add_a - - let add_l = add_l - - let get_struct_decl = Global.get_struct_decl - - let get_struct_member_type = Global.get_struct_member_type - - let get_datatype = Global.get_datatype - - let get_datatype_constr = Global.get_datatype_constr - - let get_resource_predicate_def = Global.get_resource_predicate_def - - let get_logical_function_def = Global.get_logical_function_def - - let get_lemma = Global.get_lemma - - let get_fun_decl = Global.get_fun_decl - - let ensure_base_type = ensure_base_type - - let lift = function Ok x -> return x | Error x -> fail (fun _ -> x) -end - module WellTyped = struct type nonrec 'a t = 'a t - include WellTyped.Make (NoSolver) - include Exposed + include WellTyped.Lift (ErrorReader) end diff --git a/backend/cn/lib/wellTyped.ml b/backend/cn/lib/wellTyped.ml index 2433d43c2..d1a0eee8f 100644 --- a/backend/cn/lib/wellTyped.ml +++ b/backend/cn/lib/wellTyped.ml @@ -9,14 +9,133 @@ let squotes, warn, dot, string, debug, item, colon, comma = Pp.(squotes, warn, dot, string, debug, item, colon, comma) +module GlobalReader = struct + type 'a t = Context.t -> ('a * Context.t) Or_TypeError.t + + let return x s = Ok (x, s) + + let bind x f s = match x s with Ok (y, s') -> f y s' | Error err -> Error err + + let get () s = Ok (s, s) + + let to_global ctxt = ctxt.Context.global + + type global = Global.t + + type state = Context.t +end + +module NoSolver = struct + include GlobalReader + include Global.Lift (GlobalReader) + + type failure = TypeErrors.t + + let liftFail typeErr = typeErr + + let pure x s = match x s with Ok (y, _) -> Ok (y, s) | Error err -> Error err + + let fail (typeErr : failure) : 'a t = fun _ -> Error (liftFail typeErr) + + let update f s = Ok ((), f s) + + let lookup f : _ t = fun s -> Ok (f s, s) + + let ( let@ ) = bind + + let bound_a sym = lookup (Context.bound_a sym) + + let bound_l sym = lookup (Context.bound_l sym) + + let get_a sym = lookup (Context.get_a sym) + + let get_l sym = lookup (Context.get_l sym) + + let add_a sym bt info = update (Context.add_a sym bt info) + + let add_l sym bt info = update (Context.add_l sym bt info) + + let ensure_base_type loc ~expect has : unit t = + if BT.equal has expect then + return () + else + fail { loc; msg = Mismatch { has = BT.pp has; expect = BT.pp expect } } + + + let error_if_none opt loc msg = + let@ opt in + Option.fold + opt + ~some:return + ~none: + (let@ msg in + fail { loc; msg }) + + + let get_logical_function_def_opt id = get_logical_function_def id + + let get_logical_function_def loc id = + error_if_none + (get_logical_function_def id) + loc + (let@ res = get_resource_predicate_def id in + return (TypeErrors.Unknown_logical_function { id; resource = Option.is_some res })) + + + let get_struct_decl loc tag = + error_if_none (get_struct_decl tag) loc (return (TypeErrors.Unknown_struct tag)) + + + let get_datatype loc tag = + error_if_none (get_datatype tag) loc (return (TypeErrors.Unknown_datatype tag)) + + + let get_datatype_constr loc tag = + error_if_none + (get_datatype_constr tag) + loc + (return (TypeErrors.Unknown_datatype_constr tag)) + + + let get_member_type loc member layout : Sctypes.t t = + let member_types = Memory.member_types layout in + match List.assoc_opt Id.equal member member_types with + | Some membertyp -> return membertyp + | None -> fail { loc; msg = Unexpected_member (List.map fst member_types, member) } + + + let get_struct_member_type loc tag member = + let@ decl = get_struct_decl loc tag in + let@ ty = get_member_type loc member decl in + return ty + + + let get_fun_decl loc fsym = + error_if_none (get_fun_decl fsym) loc (return (TypeErrors.Unknown_function fsym)) + + + let get_lemma loc lsym = + error_if_none (get_lemma lsym) loc (return (TypeErrors.Unknown_lemma lsym)) + + + let get_resource_predicate_def loc id = + error_if_none + (get_resource_predicate_def id) + loc + (let@ log = get_logical_function_def_opt id in + return (TypeErrors.Unknown_resource_predicate { id; logical = Option.is_some log })) + + + let lift = function Ok x -> return x | Error x -> fail x +end + let use_ity = ref true -module Make (Monad : Sigs.NoSolver) = struct -open Monad +open NoSolver -let fail typeErr = fail (Monad.liftFail typeErr) +let fail typeErr = fail (NoSolver.liftFail typeErr) -open Effectful.Make (Monad) +open Effectful.Make (NoSolver) let illtyped_index_term (loc : Locations.t) it has ~expected ~reason = let reason = @@ -2394,4 +2513,90 @@ module Exposed = struct let ensure_bits_type = ensure_bits_type end -end[@@ocamlformat "disable"] + +module type ErrorReader = sig + type 'a t + + val return : 'a -> 'a t + + val bind : 'a t -> ('a -> 'b t) -> 'b t + + type state + + val get : unit -> state t + + val to_context : state -> Context.t + + val lift : 'a Or_TypeError.t -> 'a t +end + +module Lift (M : ErrorReader) : Sigs.Exposed with type 'a t := 'a M.t = struct + let lift1 f x = + let ( let@ ) = M.bind in + let@ state = M.get () in + let context = M.to_context state in + M.lift (Result.map fst (f x context)) + + + let lift2 f x y = + let ( let@ ) = M.bind in + let@ state = M.get () in + let context = M.to_context state in + M.lift (Result.map fst (f x y context)) + + + let lift3 f x y z = + let ( let@ ) = M.bind in + let@ state = M.get () in + let context = M.to_context state in + M.lift (Result.map fst (f x y z context)) + + + let datatype x = lift1 Exposed.datatype x + + let datatype_recursion = lift1 Exposed.datatype_recursion + + let lemma x y z = lift3 Exposed.lemma x y z + + let function_ = lift1 Exposed.function_ + + let predicate = lift1 Exposed.predicate + + let label_context = Exposed.label_context + + let to_argument_type = Exposed.to_argument_type + + let procedure x y = lift2 Exposed.procedure x y + + let integer_annot = Exposed.integer_annot + + let infer_expr x y = lift2 Exposed.infer_expr x y + + let check_expr x y z = lift3 Exposed.check_expr x y z + + let function_type = lift3 Exposed.function_type + + let logical_constraint = lift2 Exposed.logical_constraint + + let oarg_bt_of_pred = lift2 Exposed.oarg_bt_of_pred + + let default_quantifier_bt = Exposed.default_quantifier_bt + + let infer_term x = lift1 Exposed.infer_term x + + let check_term x y z = lift3 Exposed.check_term x y z + + let check_ct = lift2 Exposed.check_ct + + let compare_by_fst_id = Exposed.compare_by_fst_id + + let ensure_same_argument_number loc type_ n ~expect = + let ( let@ ) = M.bind in + let@ state = M.get () in + let context = M.to_context state in + M.lift + (Result.map fst (Exposed.ensure_same_argument_number loc type_ n ~expect context)) + + + let ensure_bits_type = lift2 Exposed.ensure_bits_type +end diff --git a/backend/cn/lib/wellTyped.mli b/backend/cn/lib/wellTyped.mli index 38d28f3d2..ef35fc3a2 100644 --- a/backend/cn/lib/wellTyped.mli +++ b/backend/cn/lib/wellTyped.mli @@ -1,5 +1,23 @@ val use_ity : bool ref -module Make : functor (Monad : Sigs.NoSolver) -> sig - module Exposed : Sigs.Exposed with type 'a t := 'a Monad.t +module NoSolver : Sigs.NoSolver + +module Exposed : Sigs.Exposed with type 'a t := 'a NoSolver.t + +module type ErrorReader = sig + type 'a t + + val return : 'a -> 'a t + + val bind : 'a t -> ('a -> 'b t) -> 'b t + + type state + + val get : unit -> state t + + val to_context : state -> Context.t + + val lift : 'a Or_TypeError.t -> 'a t end + +module Lift : functor (M : ErrorReader) -> Sigs.Exposed with type 'a t := 'a M.t