Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lean: add support for register definitions #894

Merged
merged 24 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/lib/state.ml
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,33 @@ let register_refs_coq doc_id coq_record_update env registers =
in
separate hardline [generic_convs; refs; getters_setters]

let register_refs_lean doc_id doc_typ registers =
let generic_convs = separate_map hardline string [""; "variable [MonadReg]"; ""; "open MonadReg"; ""] in
let register_ref (typ, id, _) =
let idd = doc_id id in
let typp = doc_typ typ in
concat
[
string " set_";
idd;
space;
colon;
space;
typp;
string " -> SailM Unit";
hardline;
string " get_";
idd;
space;
colon;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Space needed before the colon

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed 👍

space;
string "SailM ";
typp;
]
in
let refs = separate_map hardline register_ref registers in
separate hardline [string "class MonadReg where"; refs; generic_convs]

let generate_regstate_defs ctx env ast =
let defs = ast.defs in
let registers = find_registers defs in
Expand Down
96 changes: 91 additions & 5 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,82 @@ let string_of_exp_con (E_aux (e, _)) =
| E_vector _ -> "E_vector"
| E_let _ -> "E_let"

let string_of_pat_con (P_aux (p, _)) =
match p with
| P_app _ -> "P_app"
| P_wild -> "P_wild"
| P_lit _ -> "P_lit"
| P_or _ -> "P_or"
| P_not _ -> "P_not"
| P_as _ -> "P_as"
| P_typ _ -> "P_typ"
| P_id _ -> "P_id"
| P_var _ -> "P_var"
| P_vector _ -> "P_vector"
| P_vector_concat _ -> "P_vector_concat"
| P_vector_subrange _ -> "P_vector_subrange"
| P_tuple _ -> "P_tuple"
| P_list _ -> "P_list"
| P_cons _ -> "P_cons"
| P_string_append _ -> "P_string_append"
| P_struct _ -> "P_struct"

let rec doc_pat ctxt apat_needed (P_aux (p, (l, annot)) as pat) =
let env = env_of_annot (l, annot) in
let typ = Env.expand_synonyms env (typ_of_annot (l, annot)) in
match p with
| P_typ (ptyp, p) ->
let doc_p = doc_pat ctxt true p in
doc_p
| P_id id -> doc_id_ctor id
| _ -> failwith ("Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")

(* Copied from the Coq PP *)
let rebind_cast_pattern_vars pat typ exp =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does Lean actually need this? (It's to deal with a limitation of Coq's let-binding patterns.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lean can pattern match in let-bindings if there's a single constructor, is that enough?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's roughly what Coq allows, except with the limitation that you can't have type annotations deep inside the pattern.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, now that I think about it, it may also have been to ensure that sufficient bitvector casts are inserted (e.g., when you have bits(8 * 'n) and need bits('n * 8).

let rec aux pat typ =
match (pat, typ) with
| P_aux (P_typ (target_typ, P_aux (P_id id, (l, ann))), _), source_typ when not (is_enum (env_of exp) id) ->
if Typ.compare target_typ source_typ == 0 then []
else (
let l = Parse_ast.Generated l in
let cast_annot = Type_check.replace_typ source_typ ann in
let e_annot = Type_check.mk_tannot (env_of exp) source_typ in
[LB_aux (LB_val (pat, E_aux (E_id id, (l, e_annot))), (l, ann))]
)
| P_aux (P_tuple pats, _), Typ_aux (Typ_tuple typs, _) -> List.concat (List.map2 aux pats typs)
| _ -> []
in
let add_lb (E_aux (_, ann) as exp) lb = E_aux (E_let (lb, exp), ann) in
(* Don't introduce new bindings at the top-level, we'd just go into a loop. *)
let lbs =
match (pat, typ) with
| P_aux (P_tuple pats, _), Typ_aux (Typ_tuple typs, _) -> List.concat (List.map2 aux pats typs)
| _ -> []
in
List.fold_left add_lb exp lbs

let rec doc_exp ctx (E_aux (e, (l, annot)) as full_exp) =
let env = env_of_tannot annot in
match e with
| E_id id -> string (string_of_id id) (* TODO replace by a translating via a binding map *)
| E_lit l -> doc_lit l
| E_app (Id_aux (Id "internal_pick", _), _) ->
string "sorry" (* TODO replace by actual implementation of internal_pick *)
| E_internal_plet _ -> string "sorry" (* TODO replace by actual implementation of internal_plet *)
| E_internal_plet (pat, e1, e2) ->
(* doc_exp ctxt e1 ^^ hardline ^^ doc_exp ctxt e2 *)
let e0 = doc_pat ctx false pat in
let e1_pp = doc_exp ctx e1 in
let e2' = rebind_cast_pattern_vars pat (typ_of e1) e2 in
let e2_pp = doc_exp ctx e2' in
(* infix 0 1 middle e1_pp e2_pp *)
let e0_pp =
begin
match pat with
| P_aux (P_typ (_, P_aux (P_wild, _)), _) -> string ""
| _ -> flow (break 1) [string "let"; e0; string ":="] ^^ space
end
in
nest 2 (e0_pp ^^ e1_pp) ^^ hardline ^^ e2_pp
| E_app (f, args) ->
let d_id =
if Env.is_extern f env "lean" then string (Env.get_extern f env "lean")
Expand All @@ -280,7 +348,13 @@ let rec doc_exp ctx (E_aux (e, (l, annot)) as full_exp) =
let d_args = List.map (doc_exp ctx) args in
nest 2 (parens (flow (break 1) (d_id :: d_args)))
| E_vector vals -> failwith "vector found"
| E_typ (typ, e) -> parens (separate space [doc_exp ctx e; colon; doc_typ ctx typ])
| E_typ (typ, e) -> (
match e with
| E_aux (E_assign _, _) -> doc_exp ctx e
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might try and get rid of these silly unit type annotations on assignments when they're generated, as they affect several backends.

| E_aux (E_app (Id_aux (Id "internal_pick", _), _), _) ->
string "return " ^^ nest 7 (parens (flow (break 1) [doc_exp ctx e; colon; doc_typ ctx typ]))
| _ -> parens (flow (break 1) [doc_exp ctx e; colon; doc_typ ctx typ])
)
| E_tuple es -> parens (separate_map (comma ^^ space) (doc_exp ctx) es)
| E_let (LB_aux (LB_val (lpat, lexp), _), e) ->
let id =
Expand All @@ -297,6 +371,13 @@ let rec doc_exp ctx (E_aux (e, (l, annot)) as full_exp) =
| E_struct_update (exp, fexps) ->
let args = List.map (doc_fexp ctx) fexps in
braces (space ^^ doc_exp ctx exp ^^ string " with " ^^ separate (comma ^^ space) args ^^ space)
| E_assign ((LE_aux (le_act, tannot) as le), e) -> (
match le_act with
| LE_id id | LE_typ (_, id) -> string "set_" ^^ doc_id_ctor id ^^ space ^^ doc_exp ctx e
| LE_deref e -> string "sorry /- deref -/"
| _ -> failwith ("assign " ^ string_of_lexp le ^ "not implemented yet")
)
| E_internal_return e -> nest 2 (string "return" ^^ space ^^ nest 5 (doc_exp ctx e))
| _ -> failwith ("Expression " ^ string_of_exp_con full_exp ^ " " ^ string_of_exp full_exp ^ " not translatable yet.")

and doc_fexp ctx (FE_aux (FE_fexp (field, exp), _)) = doc_id_ctor field ^^ string " := " ^^ doc_exp ctx exp
Expand Down Expand Up @@ -363,8 +444,7 @@ let doc_funcl_body (FCL_aux (FCL_funcl (id, pexp), annot)) =
let env = env_of_tannot (snd annot) in
let ctx = initial_context env in
let _, _, exp, _ = destruct_pexp pexp in
let is_monadic = effectful (effect_of exp) in
if is_monadic then nest 2 (flow (break 1) [string "return"; doc_exp ctx exp]) else doc_exp ctx exp
doc_exp (initial_context env) exp

let doc_funcl ctx funcl =
let comment, signature, env = doc_funcl_init funcl in
Expand Down Expand Up @@ -425,6 +505,12 @@ let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.en

let pp_ast_lean (env : Type_check.env) ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o =
let defs = remove_imports defs 0 in
let regs = State.find_registers defs in
let register_refs =
match regs with
| [] -> empty
| _ -> State.register_refs_lean doc_id_ctor (doc_typ (initial_context env)) regs ^^ hardline
in
let output : document = separate_map empty (doc_def (initial_context env)) defs in
print o output;
print o (register_refs ^^ output);
()
20 changes: 13 additions & 7 deletions test/lean/bitfield.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ open Sail
def cr_type := (BitVec 8)

def undefined_cr_type (lit : Unit) : SailM (BitVec 8) :=
return ((undefined_bitvector 8) : (BitVec 8))
((undefined_bitvector 8) : (BitVec 8))

def Mk_cr_type (v : (BitVec 8)) : (BitVec 8) :=
v
Expand All @@ -17,7 +17,8 @@ def _update_cr_type_bits (v : (BitVec 8)) (x : (BitVec 8)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v (HSub.hSub 8 1) 0 x)

def _set_cr_type_bits (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 8)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def _get_cr_type_CR0 (v : (BitVec 8)) : (BitVec 4) :=
(Sail.BitVec.extractLsb v 7 4)
Expand All @@ -26,7 +27,8 @@ def _update_cr_type_CR0 (v : (BitVec 8)) (x : (BitVec 4)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 7 4 x)

def _set_cr_type_CR0 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 4)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def _get_cr_type_CR1 (v : (BitVec 8)) : (BitVec 2) :=
(Sail.BitVec.extractLsb v 3 2)
Expand All @@ -35,7 +37,8 @@ def _update_cr_type_CR1 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 3 2 x)

def _set_cr_type_CR1 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 2)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def _get_cr_type_CR3 (v : (BitVec 8)) : (BitVec 2) :=
(Sail.BitVec.extractLsb v 1 0)
Expand All @@ -44,7 +47,8 @@ def _update_cr_type_CR3 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 1 0 x)

def _set_cr_type_CR3 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 2)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def _get_cr_type_GT (v : (BitVec 8)) : (BitVec 1) :=
(Sail.BitVec.extractLsb v 6 6)
Expand All @@ -53,7 +57,8 @@ def _update_cr_type_GT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 6 6 x)

def _set_cr_type_GT (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 1)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def _get_cr_type_LT (v : (BitVec 8)) : (BitVec 1) :=
(Sail.BitVec.extractLsb v 7 7)
Expand All @@ -62,7 +67,8 @@ def _update_cr_type_LT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 7 7 x)

def _set_cr_type_LT (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 1)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def initialize_registers : Unit :=
()
Expand Down
16 changes: 16 additions & 0 deletions test/lean/registers.expected.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import Out.Sail.Sail

open Sail

class MonadReg where
set_R0 : (BitVec 64) -> SailM Unit
get_R0 : SailM (BitVec 64)

lfrenot marked this conversation as resolved.
Show resolved Hide resolved
variable [MonadReg]

open MonadReg

def initialize_registers : SailM Unit :=
let w__0 := (undefined_bitvector 64)
set_R0 w__0

5 changes: 5 additions & 0 deletions test/lean/registers.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
default Order dec

$include <prelude.sail>

register R0 : bits(64)
5 changes: 4 additions & 1 deletion test/lean/struct.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ structure My_struct where
field2 : (BitVec 1)

def undefined_My_struct (lit : Unit) : SailM My_struct :=
return sorry
let w__0 := (undefined_int ())
let w__1 := (undefined_bit ())
return { field1 := w__0
field2 := w__1 }

def struct_field2 (s : My_struct) : (BitVec 1) :=
s.field2
Expand Down
Loading