diff --git a/src/sail_lean_backend/Sail/Sail.lean b/src/sail_lean_backend/Sail/Sail.lean index dc3dc9f92..965b9f890 100644 --- a/src/sail_lean_backend/Sail/Sail.lean +++ b/src/sail_lean_backend/Sail/Sail.lean @@ -1,7 +1,42 @@ +import Std.Data.DHashMap namespace Sail -/- Placeholder for a future implementation of the state monad some Sail functions use. -/ -abbrev SailM := StateM Unit +section Regs + +variable {Register : Type} {RegisterType : Register → Type} [DecidableEq Register] [Hashable Register] + +/- The Units are placeholders for a future implementation of the state monad some Sail functions use. -/ +abbrev Error := Unit + +structure SequentialState (RegisterType : Register → Type) where + regs : Std.DHashMap Register RegisterType + mem : Unit + tags : Unit + +inductive RegisterRef (RegisterType : Register → Type) : Type → Type where + | Reg (r : Register) : RegisterRef _ (RegisterType r) + +abbrev PreSailM (RegisterType : Register → Type) := + EStateM Error (SequentialState RegisterType) + +def writeReg (r : Register) (v : RegisterType r) : PreSailM RegisterType Unit := + modify fun s => { s with regs := s.regs.insert r v } + +def readReg (r : Register) : PreSailM RegisterType (RegisterType r) := do + let .some s := (← get).regs.get? r + | throw () + pure s + +def readRegRef (reg_ref : @RegisterRef Register RegisterType α) : PreSailM RegisterType α := do + match reg_ref with | .Reg r => readReg r + +def writeRegRef (reg_ref : @RegisterRef Register RegisterType α) (a : α) : + PreSailM RegisterType Unit := do + match reg_ref with | .Reg r => writeReg r a + +def reg_deref (reg_ref : @RegisterRef Register RegisterType α) := readRegRef reg_ref + +end Regs namespace BitVec @@ -32,13 +67,3 @@ def updateSubrange {w : Nat} (x : BitVec w) (hi lo : Nat) (y : BitVec (hi - lo + end BitVec end Sail - -structure RegisterRef (regstate regval a : Type) where - name : String - read_from : regstate -> a - write_to : a -> regstate -> regstate - of_regval : regval -> Option a - regval_of : a -> regval - -def undefined_bitvector (w : Nat) : BitVec w := - 0 diff --git a/src/sail_lean_backend/pretty_print_lean.ml b/src/sail_lean_backend/pretty_print_lean.ml index a4fb34dd1..9cfb18e32 100644 --- a/src/sail_lean_backend/pretty_print_lean.ml +++ b/src/sail_lean_backend/pretty_print_lean.ml @@ -152,8 +152,7 @@ let rec doc_typ ctx (Typ_aux (t, _) as typ) = parens (string "BitVec " ^^ doc_nexp ctx m) | Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp x, _)]) -> if provably_nneg ctx x then string "Nat" else string "Int" | Typ_app (Id_aux (Id "register", _), t_app) -> - string "RegisterRef Unit Unit " - (* TODO: Replace units with real types. *) ^^ separate_map comma (doc_typ_app ctx) t_app + string "RegisterRef RegisterType " ^^ separate_map comma (doc_typ_app ctx) t_app | Typ_app (Id_aux (Id "implicit", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]) -> underscore (* TODO check if the type of implicit arguments can really be always inferred *) | Typ_tuple ts -> parens (separate_map (space ^^ string "×" ^^ space) (doc_typ ctx) ts) @@ -273,6 +272,61 @@ let string_of_exp_con (E_aux (e, _)) = | E_vector _ -> "E_vector" | E_let _ -> "E_let" +let string_of_pat_con (P_aux (p, _)) = + match p with + | P_app _ -> "P_app" + | P_wild -> "P_wild" + | P_lit _ -> "P_lit" + | P_or _ -> "P_or" + | P_not _ -> "P_not" + | P_as _ -> "P_as" + | P_typ _ -> "P_typ" + | P_id _ -> "P_id" + | P_var _ -> "P_var" + | P_vector _ -> "P_vector" + | P_vector_concat _ -> "P_vector_concat" + | P_vector_subrange _ -> "P_vector_subrange" + | P_tuple _ -> "P_tuple" + | P_list _ -> "P_list" + | P_cons _ -> "P_cons" + | P_string_append _ -> "P_string_append" + | P_struct _ -> "P_struct" + +let rec doc_pat ctxt apat_needed (P_aux (p, (l, annot)) as pat) = + let env = env_of_annot (l, annot) in + let typ = Env.expand_synonyms env (typ_of_annot (l, annot)) in + match p with + | P_typ (ptyp, p) -> + let doc_p = doc_pat ctxt true p in + doc_p + | P_id id -> doc_id_ctor id + | P_wild -> underscore + | _ -> failwith ("Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.") + +(* Copied from the Coq PP *) +let rebind_cast_pattern_vars pat typ exp = + let rec aux pat typ = + match (pat, typ) with + | P_aux (P_typ (target_typ, P_aux (P_id id, (l, ann))), _), source_typ when not (is_enum (env_of exp) id) -> + if Typ.compare target_typ source_typ == 0 then [] + else ( + let l = Parse_ast.Generated l in + let cast_annot = Type_check.replace_typ source_typ ann in + let e_annot = Type_check.mk_tannot (env_of exp) source_typ in + [LB_aux (LB_val (pat, E_aux (E_id id, (l, e_annot))), (l, ann))] + ) + | P_aux (P_tuple pats, _), Typ_aux (Typ_tuple typs, _) -> List.concat (List.map2 aux pats typs) + | _ -> [] + in + let add_lb (E_aux (_, ann) as exp) lb = E_aux (E_let (lb, exp), ann) in + (* Don't introduce new bindings at the top-level, we'd just go into a loop. *) + let lbs = + match (pat, typ) with + | P_aux (P_tuple pats, _), Typ_aux (Typ_tuple typs, _) -> List.concat (List.map2 aux pats typs) + | _ -> [] + in + List.fold_left add_lb exp lbs + let wrap_with_pure (needs_return : bool) (d : document) = if needs_return then parens (nest 2 (flow space [string "pure"; d])) else d @@ -291,16 +345,30 @@ let rec doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) = in match e with | E_id id -> - (* TODO replace by a translating via a binding map *) - wrap_with_pure as_monadic (string (string_of_id id)) + if Env.is_register id env then string "readReg " ^^ doc_id_ctor id + else wrap_with_pure as_monadic (string (string_of_id id)) | E_lit l -> wrap_with_pure as_monadic (doc_lit l) | E_app (Id_aux (Id "undefined_int", _), _) (* TODO remove when we handle imports *) | E_app (Id_aux (Id "undefined_bit", _), _) (* TODO remove when we handle imports *) | E_app (Id_aux (Id "undefined_bitvector", _), _) (* TODO remove when we handle imports *) + | E_app (Id_aux (Id "undefined_bool", _), _) (* TODO remove when we handle imports *) + | E_app (Id_aux (Id "undefined_nat", _), _) (* TODO remove when we handle imports *) | E_app (Id_aux (Id "internal_pick", _), _) -> (* TODO replace by actual implementation of internal_pick *) string "sorry" - | E_internal_plet _ -> string "sorry" (* TODO replace by actual implementation of internal_plet *) + | E_internal_plet (pat, e1, e2) -> + let e0 = doc_pat ctx false pat in + let e1_pp = doc_exp false ctx e1 in + let e2' = rebind_cast_pattern_vars pat (typ_of e1) e2 in + let e2_pp = doc_exp false ctx e2' in + let e0_pp = + begin + match pat with + | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> string "" + | _ -> flow (break 1) [string "let"; e0; string "←"] ^^ space + end + in + nest 2 (e0_pp ^^ e1_pp) ^^ hardline ^^ e2_pp | E_app (f, args) -> let d_id = if Env.is_extern f env "lean" then string (Env.get_extern f env "lean") @@ -339,6 +407,12 @@ let rec doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) = (* TODO *) wrap_with_pure as_monadic (braces (space ^^ doc_exp false ctx exp ^^ string " with " ^^ separate (comma ^^ space) args ^^ space)) + | E_assign ((LE_aux (le_act, tannot) as le), e) -> ( + match le_act with + | LE_id id | LE_typ (_, id) -> string "writeReg " ^^ doc_id_ctor id ^^ space ^^ doc_exp false ctx e + | LE_deref e' -> string "writeRegRef " ^^ doc_exp false ctx e' ^^ space ^^ doc_exp false ctx e + | _ -> failwith ("assign " ^ string_of_lexp le ^ "not implemented yet") + ) | _ -> failwith ("Expression " ^ string_of_exp_con full_exp ^ " " ^ string_of_exp full_exp ^ " not translatable yet.") and doc_fexp with_arrow ctx (FE_aux (FE_fexp (field, e), _)) = @@ -404,7 +478,7 @@ let doc_funcl_body (FCL_aux (FCL_funcl (id, pexp), annot)) = let ctx = initial_context env in let _, _, exp, _ = destruct_pexp pexp in let is_monadic = effectful (effect_of exp) in - doc_exp is_monadic ctx exp + doc_exp is_monadic (initial_context env) exp let doc_funcl ctx funcl = let comment, signature, env = doc_funcl_init funcl in @@ -449,11 +523,16 @@ let doc_typdef ctx (TD_aux (td, tannot) as full_typdef) = nest 2 (flow (break 1) [string "def"; string id; colon; string "Int"; coloneq; doc_nexp ctx ne]) | _ -> failwith ("Type definition " ^ string_of_type_def_con full_typdef ^ " not translatable yet.") -let doc_def ctx (DEF_aux (aux, def_annot) as def) = - match aux with - | DEF_fundef fdef -> group (doc_fundef ctx fdef) ^/^ hardline - | DEF_type tdef -> group (doc_typdef ctx tdef) ^/^ hardline - | _ -> empty +let rec doc_defs_aux ctx defs types fundefs = + match defs with + | [] -> (types, fundefs) + | DEF_aux (DEF_fundef fdef, _) :: defs' -> + doc_defs_aux ctx defs' types (fundefs ^^ group (doc_fundef ctx fdef) ^/^ hardline) + | DEF_aux (DEF_type tdef, _) :: defs' -> + doc_defs_aux ctx defs' (types ^^ group (doc_typdef ctx tdef) ^/^ hardline) fundefs + | _ :: defs' -> doc_defs_aux ctx defs' types fundefs + +let doc_defs ctx defs = doc_defs_aux ctx defs empty empty (* Remove all imports for now, they will be printed in other files. Probably just for testing. *) let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.env) def list) depth = @@ -463,8 +542,51 @@ let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.en | DEF_aux (DEF_pragma ("include_end", _, _), _) :: ds -> remove_imports ds (depth - 1) | d :: ds -> if depth > 0 then remove_imports ds depth else d :: remove_imports ds depth +let opt_cons v = function None -> Some [v] | Some t -> Some (v :: t) + +let reg_type_name typ_id = prepend_id "register_" typ_id +let reg_case_name typ_id = prepend_id "R_" typ_id +let state_field_name typ_id = append_id typ_id "_s" +let ref_name reg = append_id reg "_ref" +let add_reg_typ env (typ_map, regs_map) (typ, id, has_init) = + let typ_id = State.id_of_regtyp IdSet.empty typ in + (Bindings.add typ_id typ typ_map, Bindings.update typ_id (opt_cons id) regs_map) + +let register_enums registers = + separate hardline + [ + string "inductive Register : Type where"; + separate_map hardline (fun (_, id, _) -> string " | " ^^ doc_id_ctor id) registers; + string " deriving DecidableEq, Hashable"; + string "open Register"; + empty; + ] + +let type_enum ctx registers = + separate hardline + [ + string "abbrev RegisterType : Register → Type"; + separate_map hardline + (fun (typ, id, _) -> string " | ." ^^ doc_id_ctor id ^^ string " => " ^^ doc_typ ctx typ) + registers; + empty; + ] + +let doc_reg_info env registers = + let bare_ctx = initial_context env in + separate hardline + [ + register_enums registers; + type_enum bare_ctx registers; + string "abbrev SailM := PreSailM RegisterType"; + empty; + empty; + ] + let pp_ast_lean (env : Type_check.env) ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o = let defs = remove_imports defs 0 in - let output : document = separate_map empty (doc_def (initial_context env)) defs in - print o output; + let regs = State.find_registers defs in + let register_refs = match regs with [] -> empty | _ -> doc_reg_info env regs in + let types, fundefs = doc_defs (initial_context env) defs in + print o (types ^^ register_refs ^^ fundefs); () diff --git a/test/lean/bitfield.expected.lean b/test/lean/bitfield.expected.lean index 72d994159..96f086d2b 100644 --- a/test/lean/bitfield.expected.lean +++ b/test/lean/bitfield.expected.lean @@ -4,6 +4,16 @@ open Sail def cr_type := (BitVec 8) +inductive Register : Type where + | R + deriving DecidableEq, Hashable +open Register + +abbrev RegisterType : Register → Type + | .R => (BitVec 8) + +abbrev SailM := PreSailM RegisterType + def undefined_cr_type (lit : Unit) : SailM (BitVec 8) := do sorry @@ -16,8 +26,9 @@ def _get_cr_type_bits (v : (BitVec 8)) : (BitVec 8) := def _update_cr_type_bits (v : (BitVec 8)) (x : (BitVec 8)) : (BitVec 8) := (Sail.BitVec.updateSubrange v (HSub.hSub 8 1) 0 x) -def _set_cr_type_bits (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 8)) : SailM Unit := do - sorry +def _set_cr_type_bits (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 8)) : SailM Unit := do + let r ← (reg_deref r_ref) + writeRegRef r_ref (_update_cr_type_bits r v) def _get_cr_type_CR0 (v : (BitVec 8)) : (BitVec 4) := (Sail.BitVec.extractLsb v 7 4) @@ -25,8 +36,9 @@ def _get_cr_type_CR0 (v : (BitVec 8)) : (BitVec 4) := def _update_cr_type_CR0 (v : (BitVec 8)) (x : (BitVec 4)) : (BitVec 8) := (Sail.BitVec.updateSubrange v 7 4 x) -def _set_cr_type_CR0 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 4)) : SailM Unit := do - sorry +def _set_cr_type_CR0 (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 4)) : SailM Unit := do + let r ← (reg_deref r_ref) + writeRegRef r_ref (_update_cr_type_CR0 r v) def _get_cr_type_CR1 (v : (BitVec 8)) : (BitVec 2) := (Sail.BitVec.extractLsb v 3 2) @@ -34,8 +46,9 @@ def _get_cr_type_CR1 (v : (BitVec 8)) : (BitVec 2) := def _update_cr_type_CR1 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) := (Sail.BitVec.updateSubrange v 3 2 x) -def _set_cr_type_CR1 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 2)) : SailM Unit := do - sorry +def _set_cr_type_CR1 (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 2)) : SailM Unit := do + let r ← (reg_deref r_ref) + writeRegRef r_ref (_update_cr_type_CR1 r v) def _get_cr_type_CR3 (v : (BitVec 8)) : (BitVec 2) := (Sail.BitVec.extractLsb v 1 0) @@ -43,8 +56,9 @@ def _get_cr_type_CR3 (v : (BitVec 8)) : (BitVec 2) := def _update_cr_type_CR3 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) := (Sail.BitVec.updateSubrange v 1 0 x) -def _set_cr_type_CR3 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 2)) : SailM Unit := do - sorry +def _set_cr_type_CR3 (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 2)) : SailM Unit := do + let r ← (reg_deref r_ref) + writeRegRef r_ref (_update_cr_type_CR3 r v) def _get_cr_type_GT (v : (BitVec 8)) : (BitVec 1) := (Sail.BitVec.extractLsb v 6 6) @@ -52,8 +66,9 @@ def _get_cr_type_GT (v : (BitVec 8)) : (BitVec 1) := def _update_cr_type_GT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) := (Sail.BitVec.updateSubrange v 6 6 x) -def _set_cr_type_GT (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 1)) : SailM Unit := do - sorry +def _set_cr_type_GT (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 1)) : SailM Unit := do + let r ← (reg_deref r_ref) + writeRegRef r_ref (_update_cr_type_GT r v) def _get_cr_type_LT (v : (BitVec 8)) : (BitVec 1) := (Sail.BitVec.extractLsb v 7 7) @@ -61,9 +76,10 @@ def _get_cr_type_LT (v : (BitVec 8)) : (BitVec 1) := def _update_cr_type_LT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) := (Sail.BitVec.updateSubrange v 7 7 x) -def _set_cr_type_LT (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 1)) : SailM Unit := do - sorry +def _set_cr_type_LT (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 1)) : SailM Unit := do + let r ← (reg_deref r_ref) + writeRegRef r_ref (_update_cr_type_LT r v) -def initialize_registers : Unit := - () +def initialize_registers : SailM Unit := do + writeReg R (undefined_cr_type ()) diff --git a/test/lean/bitfield.sail b/test/lean/bitfield.sail index 2969f5927..70b259e69 100644 --- a/test/lean/bitfield.sail +++ b/test/lean/bitfield.sail @@ -9,3 +9,5 @@ bitfield cr_type : bits(8) = { CR1 : 3 .. 2, CR3 : 1 .. 0 } + +register R : cr_type diff --git a/test/lean/registers.expected.lean b/test/lean/registers.expected.lean new file mode 100644 index 000000000..145c22e1b --- /dev/null +++ b/test/lean/registers.expected.lean @@ -0,0 +1,36 @@ +import Out.Sail.Sail + +open Sail + +inductive Register : Type where + | BIT + | NAT + | BOOL + | INT + | R1 + | R0 + deriving DecidableEq, Hashable +open Register + +abbrev RegisterType : Register → Type + | .BIT => (BitVec 1) + | .NAT => Nat + | .BOOL => Bool + | .INT => Int + | .R1 => (BitVec 64) + | .R0 => (BitVec 64) + +abbrev SailM := PreSailM RegisterType + +def test : SailM Int := do + writeReg INT (HAdd.hAdd (← readReg INT) 1) + readReg INT + +def initialize_registers : SailM Unit := do + writeReg R0 sorry + writeReg R1 sorry + writeReg INT sorry + writeReg BOOL sorry + writeReg NAT sorry + writeReg BIT sorry + diff --git a/test/lean/registers.sail b/test/lean/registers.sail new file mode 100644 index 000000000..43907ff6c --- /dev/null +++ b/test/lean/registers.sail @@ -0,0 +1,15 @@ +default Order dec + +$include + +register R0 : bits(64) +register R1 : bits(64) +register INT : int +register BOOL : bool +register NAT : nat +register BIT : bit + +function test () -> int = { + INT = INT + 1; + INT +} \ No newline at end of file