Skip to content

Commit

Permalink
CN: Factor out Global lookup functions
Browse files Browse the repository at this point in the history
This was ostensibly in prep for using it in WellTyped, but it turned out
that Typing also benefits from a bit of tidying around this module too.
  • Loading branch information
dc-mak committed Dec 29, 2024
1 parent c0a3a57 commit fc5fc5f
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 167 deletions.
8 changes: 4 additions & 4 deletions backend/cn/lib/cLogicalFuns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ let rec symb_exec_expr ctxt state_vars expr =
in
if Sym.Map.mem nm ctxt.c_fun_pred_map then (
let loc, l_sym = Sym.Map.find nm ctxt.c_fun_pred_map in
let@ def = get_logical_function_def loc l_sym in
let@ def = Global.get_logical_function_def loc l_sym in
rcval (IT.apply_ l_sym args_its def.Definition.Function.return_bt loc) state)
else (
let bail = fail_fun_it "not a function with a pure/logical interpretation" in
Expand Down Expand Up @@ -710,9 +710,9 @@ let c_fun_to_it id_loc glob_context (id : Sym.t) fsym def (fn : 'bty Mu.fun_map_

let upd_def (loc, sym, def_tm) =
let open Definition.Function in
let@ def = get_logical_function_def loc sym in
let@ def = Global.get_logical_function_def loc sym in
match def.body with
| Uninterp -> add_logical_function sym { def with body = Def def_tm }
| Uninterp -> Global.add_logical_function sym { def with body = Def def_tm }
| _ ->
fail_n
{ loc;
Expand All @@ -734,7 +734,7 @@ let add_logical_funs_from_c call_funinfo funs_to_convert funs =
let@ conv_defs =
ListM.mapM
(fun Mu.{ c_fun_sym; loc; l_fun_sym } ->
let@ def = get_logical_function_def loc l_fun_sym in
let@ def = Global.get_logical_function_def loc l_fun_sym in
let@ fbody =
match Pmap.lookup c_fun_sym funs with
| Some fbody -> return fbody
Expand Down
52 changes: 26 additions & 26 deletions backend/cn/lib/check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ let check_ptrval (loc : Locations.t) ~(expect : BT.t) (ptrval : pointer_value) :
unsupported loc !^"invalid function pointer"
| Some sym ->
(* just to make sure it exists *)
let@ _fun_loc, _, _ = get_fun_decl loc sym in
let@ _fun_loc, _, _ = Global.get_fun_decl loc sym in
(* the symbol of a function is the same as the symbol of its address *)
let here = Locations.other __LOC__ in
return (sym_ (sym, BT.(Loc ()), here)))
Expand Down Expand Up @@ -158,7 +158,7 @@ and check_struct
(member_values : (Id.t * Sctypes.t * mem_value) list)
: IT.t m
=
let@ layout = get_struct_decl loc tag in
let@ layout = Global.get_struct_decl loc tag in
let member_types = Memory.member_types layout in
assert (
List.for_all2
Expand Down Expand Up @@ -248,7 +248,7 @@ let rec check_value (loc : Locations.t) (Mu.V (expect, v)) : IT.t m =
| Vfunction_addr sym ->
let@ () = ensure_base_type loc ~expect (Loc ()) in
(* check it is a valid function address *)
let@ _ = get_fun_decl loc sym in
let@ _ = Global.get_fun_decl loc sym in
return (IT.sym_ (sym, BT.(Loc ()), loc))
| Vlist (_item_cbt, vals) ->
let item_bt = Mu.bt_of_value (List.hd vals) in
Expand Down Expand Up @@ -337,7 +337,7 @@ let check_single_ct loc expr =
let is_fun_addr global t =
match IT.is_sym t with
| Some (s, _) ->
if Sym.Map.mem s global.Global.fun_decls then
if Global.is_fun_decl global s then
Some s
else
None
Expand All @@ -351,7 +351,7 @@ let known_function_pointer loc p =
match already_known with
| Some _ -> (* no need to find more eqs *) return ()
| None ->
let global_funs = Sym.Map.bindings global.Global.fun_decls in
let@ global_funs = Global.get_fun_decls () in
let fun_addrs =
List.map (fun (sym, (loc, _, _)) -> IT.sym_ (sym, BT.(Loc ()), loc)) global_funs
in
Expand Down Expand Up @@ -611,7 +611,7 @@ let rec check_pexpr (pe : BT.t Mu.pexpr) (k : IT.t -> unit m) : unit m =
let@ () = ensure_base_type loc ~expect (Loc ()) in
let@ () = ensure_base_type loc ~expect:(Loc ()) (Mu.bt_of_pexpr pe) in
check_pexpr pe (fun vt ->
let@ ct = get_struct_member_type loc tag member in
let@ ct = Global.get_struct_member_type loc tag member in
let result = memberShift_ (vt, tag, member) loc in
(* This should only be called after a PtrValidForDeref, so if we
were willing to optimise, we could skip to [k result]. *)
Expand Down Expand Up @@ -759,7 +759,7 @@ let rec check_pexpr (pe : BT.t Mu.pexpr) (k : IT.t -> unit m) : unit m =
| PEstruct (tag, xs) ->
let@ () = WellTyped.check_ct loc (Struct tag) in
let@ () = ensure_base_type loc ~expect (Struct tag) in
let@ layout = get_struct_decl loc tag in
let@ layout = Global.get_struct_decl loc tag in
let member_types = Memory.member_types layout in
let@ _ =
ListM.map2M
Expand All @@ -781,7 +781,7 @@ let rec check_pexpr (pe : BT.t Mu.pexpr) (k : IT.t -> unit m) : unit m =
(* function vals are just symbols the same as the names of functions *)
let@ sym = known_function_pointer loc ptr in
(* need to conjure up the characterising 4-tuple *)
let@ _, _, c_sig = get_fun_decl loc sym in
let@ _, _, c_sig = Global.get_fun_decl loc sym in
match IT.const_of_c_sig c_sig loc with
| Some it -> k it
| None ->
Expand Down Expand Up @@ -1712,7 +1712,7 @@ let rec check_expr labels (e : BT.t Mu.expr) (k : IT.t -> unit m) : unit m =
check_pexpr f_pe (fun f_it ->
let@ _global = get_global () in
let@ fsym = known_function_pointer loc f_it in
let@ _loc, opt_ft, _ = get_fun_decl loc fsym in
let@ _loc, opt_ft, _ = Global.get_fun_decl loc fsym in
let@ ft =
match opt_ft with
| Some ft -> return ft
Expand Down Expand Up @@ -1876,7 +1876,7 @@ let rec check_expr labels (e : BT.t Mu.expr) (k : IT.t -> unit m) : unit m =
match to_instantiate with
| I_Everything -> return (fun _ -> true)
| I_Function f ->
let@ _ = get_logical_function_def loc f in
let@ _ = Global.get_logical_function_def loc f in
return (IT.mentions_call f)
| I_Good ct ->
let@ () = WellTyped.check_ct loc ct in
Expand Down Expand Up @@ -1904,7 +1904,7 @@ let rec check_expr labels (e : BT.t Mu.expr) (k : IT.t -> unit m) : unit m =
let@ () = WellTyped.check_ct loc ct in
return (Request.Owned (ct, Uninit))
| E_Pred (CN_named pn) ->
let@ _ = get_resource_predicate_def loc pn in
let@ _ = Global.get_resource_predicate_def loc pn in
return (Request.PName pn)
in
let@ it = WellTyped.infer_term it in
Expand All @@ -1922,7 +1922,7 @@ let rec check_expr labels (e : BT.t Mu.expr) (k : IT.t -> unit m) : unit m =
();
return ()
| Unfold (f, args) ->
let@ def = get_logical_function_def loc f in
let@ def = Global.get_logical_function_def loc f in
let has_args, expect_args = (List.length args, List.length def.args) in
let@ () =
WellTyped.ensure_same_argument_number
Expand All @@ -1947,7 +1947,7 @@ let rec check_expr labels (e : BT.t Mu.expr) (k : IT.t -> unit m) : unit m =
| Some body ->
add_c loc (LC.T (eq_ (apply_ f args def.return_bt loc, body) loc)))
| Apply (lemma, args) ->
let@ _loc, lemma_typ = get_lemma loc lemma in
let@ _loc, lemma_typ = Global.get_lemma loc lemma in
let args = List.map (fun arg -> (loc, arg)) args in
Spine.calltype_lemma loc ~lemma args lemma_typ (fun lrt ->
let@ _, members =
Expand Down Expand Up @@ -2171,7 +2171,7 @@ let record_tagdefs tagDefs =
(fun tag def ->
match def with
| Mu.UnionDef -> unsupported (Loc.other __LOC__) !^"todo: union types"
| StructDef layout -> add_struct_decl tag layout)
| StructDef layout -> Global.add_struct_decl tag layout)
tagDefs


Expand Down Expand Up @@ -2211,7 +2211,7 @@ let record_and_check_logical_functions funs =
ListM.iterM
(fun (name, def) ->
let@ simple_def = WellTyped.function_ { def with body = Uninterp } in
add_logical_function name simple_def)
Global.add_logical_function name simple_def)
recursive
in
(* Now check all functions in order. *)
Expand All @@ -2226,7 +2226,7 @@ let record_and_check_logical_functions funs =
^ ": "
^ Sym.pp_string name)));
let@ def = WellTyped.function_ def in
add_logical_function name def)
Global.add_logical_function name def)
funs


Expand All @@ -2236,7 +2236,7 @@ let record_and_check_resource_predicates preds =
ListM.iterM
(fun (name, def) ->
let@ simple_def = WellTyped.predicate { def with clauses = None } in
add_resource_predicate name simple_def)
Global.add_resource_predicate name simple_def)
preds
in
ListM.iteriM
Expand All @@ -2251,7 +2251,7 @@ let record_and_check_resource_predicates preds =
^ Sym.pp_string name)));
let@ def = WellTyped.predicate def in
(* add simplified def to the context *)
add_resource_predicate name def)
Global.add_resource_predicate name def)
preds


Expand Down Expand Up @@ -2331,7 +2331,7 @@ let wf_check_and_record_functions funs call_sigs =
let ft = WellTyped.to_argument_type args_and_body in
debug 6 (lazy (!^"function type" ^^^ Sym.pp fsym));
debug 6 (lazy (CF.Pp_ast.pp_doc_tree (AT.dtree RT.dtree ft)));
let@ () = add_fun_decl fsym (loc, Some ft, Pmap.find fsym call_sigs) in
let@ () = Global.add_fun_decl fsym (loc, Some ft, Pmap.find fsym call_sigs) in
(match tr with
| Trusted _ -> return ((fsym, (loc, ft)) :: trusted, checked)
| Checked -> return (trusted, (fsym, (loc, args_and_body)) :: checked))
Expand All @@ -2344,7 +2344,7 @@ let wf_check_and_record_functions funs call_sigs =
let@ ft = WellTyped.function_type "function" loc ft in
return (Some ft)
in
let@ () = add_fun_decl fsym (loc, oft, Pmap.find fsym call_sigs) in
let@ () = Global.add_fun_decl fsym (loc, oft, Pmap.find fsym call_sigs) in
return (trusted, checked))
funs
([], [])
Expand Down Expand Up @@ -2456,7 +2456,7 @@ let check_c_functions (funs : c_function list) : (string * TypeErrors.t) list m

let wf_check_and_record_lemma (lemma_s, (loc, lemma_typ)) =
let@ lemma_typ = WellTyped.lemma loc lemma_s lemma_typ in
let@ () = add_lemma lemma_s (loc, lemma_typ) in
let@ () = Global.add_lemma lemma_s (loc, lemma_typ) in
return (lemma_s, (loc, lemma_typ))


Expand Down Expand Up @@ -2566,7 +2566,7 @@ let add_stdlib_spec =
Pp.debug
2
(lazy (Pp.headline ("adding builtin spec for procedure " ^ Sym.pp_string fsym)));
add_fun_decl fsym (Locations.other __LOC__, Some ft, ct)
Global.add_fun_decl fsym (Locations.other __LOC__, Some ft, ct)
in
fun call_sigs fsym ->
match
Expand All @@ -2590,23 +2590,23 @@ let record_and_check_datatypes datatypes =
let@ () =
ListM.iterM
(fun (s, Mu.{ loc = _; cases = _ }) ->
add_datatype s { constrs = []; all_params = [] })
Global.add_datatype s { constrs = []; all_params = [] })
datatypes
in
(* check and normalise datatypes *)
let@ datatypes = ListM.mapM WellTyped.datatype datatypes in
let@ sccs = WellTyped.datatype_recursion datatypes in
let@ () = set_datatype_order (Some sccs) in
let@ () = Global.set_datatype_order (Some sccs) in
(* properly add datatypes *)
ListM.iterM
(fun (s, Mu.{ loc = _; cases }) ->
let@ () =
add_datatype
Global.add_datatype
s
{ constrs = List.map fst cases; all_params = List.concat_map snd cases }
in
ListM.iterM
(fun (c, params) -> add_datatype_constr c { params; datatype_tag = s })
(fun (c, params) -> Global.add_datatype_constr c { params; datatype_tag = s })
cases)
datatypes

Expand Down
65 changes: 65 additions & 0 deletions backend/cn/lib/global.ml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,75 @@ let get_fun_decl global sym = Sym.Map.find_opt sym global.fun_decls

let get_lemma global sym = Sym.Map.find_opt sym global.lemmata

let get_struct_decl global sym = Sym.Map.find_opt sym global.struct_decls

let get_datatype global sym = Sym.Map.find_opt sym global.datatypes

let get_datatype_constr global sym = Sym.Map.find_opt sym global.datatype_constrs

let sym_map_from_bindings xs =
List.fold_left (fun m (nm, x) -> Sym.Map.add nm x m) Sym.Map.empty xs


module type Reader = sig
type global = t

type 'a t

val return : 'a -> 'a t

val bind : 'a t -> ('a -> 'b t) -> 'b t

type state

val get : unit -> state t

val to_global : state -> global
end

module type Lifted = sig
type 'a t

val get_resource_predicate_def : Sym.t -> Definition.Predicate.t option t

val get_logical_function_def : Sym.t -> Definition.Function.t option t

val get_fun_decl
: Sym.t ->
(Cerb_location.t * AT.ft option * Sctypes.c_concrete_sig) option t

val get_lemma : Sym.t -> (Cerb_location.t * AT.lemmat) option t

val get_struct_decl : Sym.t -> Memory.struct_layout option t

val get_datatype : Sym.t -> BaseTypes.dt_info option t

val get_datatype_constr : Sym.t -> BaseTypes.constr_info option t
end

module Lift (M : Reader) : Lifted with type 'a t := 'a M.t = struct
let lift f sym =
let ( let@ ) = M.bind in
let@ state = M.get () in
let global = M.to_global state in
M.return (f global sym)


let get_resource_predicate_def = lift get_resource_predicate_def

let get_logical_function_def = lift get_logical_function_def

let get_fun_decl = lift get_fun_decl

let get_lemma = lift get_lemma

let get_struct_decl = lift get_struct_decl

let get_datatype = lift get_datatype

let get_datatype_constr = lift get_datatype_constr
end

let pp_struct_layout (tag, layout) =
item
("struct " ^ plain (Sym.pp tag) ^ " (raw)")
Expand Down
Loading

0 comments on commit fc5fc5f

Please sign in to comment.