Skip to content

Commit

Permalink
Adding a wrapper to the lean backend to be able to handle more functi…
Browse files Browse the repository at this point in the history
…ons (#801)

This adds a wrapper to the lean backend, which allows handling additional bitvector operations.
The implemented operations are: length, signExtend, zeroExtend, truncate and truncateLSB.

The wrapper functions are in the namespace Sail.BitVec to avoid collisions.

The remaining unimplemented operations require translating type conditions into Lean, which is not handled for now.
  • Loading branch information
lfrenot authored Dec 3, 2024
1 parent 161caca commit f3da818
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 11 deletions.
18 changes: 15 additions & 3 deletions lib/vector.sail
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ function neq_bits(x, y) = not_bool(eq_bits(x, y))

overload operator != = {neq_bits}

val bitvector_length = pure {coq: "length_mword", _: "length"} : forall 'n. bits('n) -> int('n)
val bitvector_length = pure {
coq: "length_mword",
lean: "Sail.BitVec.length",
_: "length"
} : forall 'n. bits('n) -> int('n)

val vector_length = pure {
ocaml: "length",
Expand All @@ -100,9 +104,15 @@ val print_bits = pure "print_bits" : forall 'n. (string, bits('n)) -> unit
$[sv_module { stderr = true }]
val prerr_bits = pure "prerr_bits" : forall 'n. (string, bits('n)) -> unit

val sail_sign_extend = pure "sign_extend" : forall 'n 'm, 'm >= 'n. (bits('n), int('m)) -> bits('m)
val sail_sign_extend = pure {
lean: "Sail.BitVec.signExtend",
_: "sign_extend"
} : forall 'n 'm, 'm >= 'n. (bits('n), int('m)) -> bits('m)

val sail_zero_extend = pure "zero_extend" : forall 'n 'm, 'm >= 'n. (bits('n), int('m)) -> bits('m)
val sail_zero_extend = pure {
lean: "Sail.BitVec.zeroExtend",
_: "zero_extend"
} : forall 'n 'm, 'm >= 'n. (bits('n), int('m)) -> bits('m)

/*!
THIS`(v, n)` truncates `v`, keeping only the _least_ significant `n` bits.
Expand All @@ -112,6 +122,7 @@ val truncate = pure {
interpreter: "vector_truncate",
lem: "vector_truncate",
coq: "vector_truncate",
lean: "Sail.BitVec.truncate",
_: "sail_truncate"
} : forall 'm 'n, 'm >= 0 & 'm <= 'n. (bits('n), int('m)) -> bits('m)

Expand All @@ -123,6 +134,7 @@ val truncateLSB = pure {
interpreter: "vector_truncateLSB",
lem: "vector_truncateLSB",
coq: "vector_truncateLSB",
lean: "Sail.BitVec.truncateLSB",
_: "sail_truncateLSB"
} : forall 'm 'n, 'm >= 0 & 'm <= 'n. (bits('n), int('m)) -> bits('m)

Expand Down
5 changes: 4 additions & 1 deletion src/bin/dune
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,7 @@
src/gen_lib/sail2_undefined_concurrency_interface.lem)
(%{workspace_root}/src/gen_lib/sail2_values.lem
as
src/gen_lib/sail2_values.lem)))
src/gen_lib/sail2_values.lem)
(%{workspace_root}/src/sail_lean_backend/Sail/sail.lean
as
src/sail_lean_backend/Sail/sail.lean)))
19 changes: 19 additions & 0 deletions src/sail_lean_backend/Sail/sail.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
namespace Sail
namespace BitVec

def length {w: Nat} (_: BitVec w): Nat := w

def signExtend {w: Nat} (x: BitVec w) (w': Nat) : BitVec w' :=
x.signExtend w'

def zeroExtend {w: Nat} (x: BitVec w) (w': Nat) : BitVec w' :=
x.zeroExtend w'

def truncate {w: Nat} (x: BitVec w) (w': Nat) : BitVec w' :=
x.truncate w'

def truncateLSB {w: Nat} (x: BitVec w) (w': Nat) : BitVec w' :=
x.extractLsb' 0 w'

end BitVec
end Sail
2 changes: 2 additions & 0 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ let rec doc_typ (Typ_aux (t, _) as typ) =
| Typ_id (Id_aux (Id "int", _)) -> string "Int"
| Typ_id (Id_aux (Id "bool", _)) -> string "Bool"
| Typ_id (Id_aux (Id "bit", _)) -> parens (string "BitVec 1")
| Typ_id (Id_aux (Id "nat", _)) -> string "Nat"
| Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp m, _)]) -> string "BitVec " ^^ doc_nexp m
| Typ_tuple ts -> parens (separate_map (space ^^ string "×" ^^ space) doc_typ ts)
| Typ_id (Id_aux (Id id, _)) -> string id
Expand Down Expand Up @@ -199,5 +200,6 @@ let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.en
let pp_ast_lean ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o =
let defs = remove_imports defs 0 in
let output : document = separate_map empty doc_def defs in
output_string o "import Sail.sail\n\n";
print o output;
()
16 changes: 9 additions & 7 deletions src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ let lean_rewrites =
("attach_effects", []);
]

let create_lake_project (out_name : string) =
let create_lake_project (out_name : string) default_sail_dir =
(* Change the base directory if the option '--lean-output-dir' is set *)
let base_dir = match !opt_lean_output_dir with Some dir -> dir | None -> "." in
let project_dir = Filename.concat base_dir out_name in
Expand All @@ -164,22 +164,24 @@ let create_lake_project (out_name : string) =
let out_name_camel = Libsail.Util.to_upper_camel_case out_name in
let lakefile = open_out (Filename.concat project_dir "lakefile.toml") in
output_string lakefile
("name = \"" ^ out_name ^ "\"\ndefaultTargets = [\"" ^ out_name_camel ^ "\"]\n\n[[lean_lib]]\nname = \""
^ out_name_camel ^ "\""
("name = \"" ^ out_name ^ "\"\ndefaultTargets = [\"" ^ out_name_camel
^ "\"]\n\n[[lean_lib]]\nname = \"Sail\"\n\n[[lean_lib]]\nname = \"" ^ out_name_camel ^ "\""
);
close_out lakefile;
let sail_dir = Reporting.get_sail_dir default_sail_dir in
let _ = Unix.system ("cp -r " ^ sail_dir ^ "/src/sail_lean_backend/Sail " ^ project_dir) in
let project_main = open_out (Filename.concat project_dir (out_name_camel ^ ".lean")) in
project_main

let output (out_name : string) ast =
let project_main = create_lake_project out_name in
let output (out_name : string) ast default_sail_dir =
let project_main = create_lake_project out_name default_sail_dir in
(* Uncomment for debug output of the Sail code after the rewrite passes *)
(* Pretty_print_sail.output_ast stdout (Type_check.strip_ast ast); *)
Pretty_print_lean.pp_ast_lean ast project_main;
close_out project_main

let lean_target out_name { ctx; ast; effect_info; env; _ } =
let lean_target out_name { default_sail_dir; ctx; ast; effect_info; env; _ } =
let out_name = match out_name with Some f -> f | None -> "out" in
output out_name ast
output out_name ast default_sail_dir

let _ = Target.register ~name:"lean" ~options:lean_options ~rewrites:lean_rewrites ~asserts_termination:true lean_target
23 changes: 23 additions & 0 deletions test/lean/bitvec_operation.expected.lean
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
import Sail.sail

def bitvector_eq (x : BitVec 16) (y : BitVec 16) : Bool :=
(Eq x y)

def bitvector_neq (x : BitVec 16) (y : BitVec 16) : Bool :=
(Ne x y)

def bitvector_len (x : BitVec 16) : Nat :=
(Sail.BitVec.length x)

def bitvector_sign_extend (x : BitVec 16) : BitVec 32 :=
(Sail.BitVec.signExtend x 32)

def bitvector_zero_extend (x : BitVec 16) : BitVec 32 :=
(Sail.BitVec.zeroExtend x 32)

def bitvector_truncate (x : BitVec 32) : BitVec 16 :=
(Sail.BitVec.truncate x 16)

def bitvector_truncateLSB (x : BitVec 32) : BitVec 16 :=
(Sail.BitVec.truncateLSB x 16)

def bitvector_append (x : BitVec 16) (y : BitVec 16) : BitVec 32 :=
(BitVec.append x y)

Expand All @@ -25,6 +42,12 @@ def bitvector_or (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
def bitvector_xor (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
(HXor.hXor x y)

def bitvector_unsigned (x : BitVec 16) : Nat :=
(BitVec.toNat x)

def bitvector_signed (x : BitVec 16) : Int :=
(BitVec.toInt x)

def initialize_registers : Unit :=
()

34 changes: 34 additions & 0 deletions test/lean/bitvec_operation.sail
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,31 @@ function bitvector_neq(x, y) = {
x != y
}

val bitvector_len : bits(16) -> nat
function bitvector_len(x) = {
length (x)
}

val bitvector_sign_extend : bits(16) -> bits(32)
function bitvector_sign_extend(x) = {
sail_sign_extend (x, 32)
}

val bitvector_zero_extend : bits(16) -> bits(32)
function bitvector_zero_extend(x) = {
sail_zero_extend (x, 32)
}

val bitvector_truncate : bits(32) -> bits(16)
function bitvector_truncate(x) = {
truncate (x, 16)
}

val bitvector_truncateLSB : bits(32) -> bits(16)
function bitvector_truncateLSB(x) = {
truncateLSB (x, 16)
}

val bitvector_append : (bits(16), bits(16)) -> bits(32)
function bitvector_append(x, y) = {
append (x, y)
Expand Down Expand Up @@ -47,3 +72,12 @@ function bitvector_xor(x, y) = {
xor_vec (x, y)
}

val bitvector_unsigned : bits(16) -> nat
function bitvector_unsigned(x) = {
unsigned (x)
}

val bitvector_signed : bits(16) -> int
function bitvector_signed(x) = {
signed (x)
}
2 changes: 2 additions & 0 deletions test/lean/enum.expected.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Sail.sail

inductive E where | A | B | C
deriving Inhabited

Expand Down
2 changes: 2 additions & 0 deletions test/lean/extern.expected.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Sail.sail

def extern_add : Int :=
(Int.add 5 4)

Expand Down
2 changes: 2 additions & 0 deletions test/lean/extern_bitvec.expected.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Sail.sail

def extern_const : BitVec 64 :=
(0xFFFF000012340000 : BitVec 64)

Expand Down
2 changes: 2 additions & 0 deletions test/lean/let.expected.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Sail.sail

def foo : BitVec 16 :=
let z := (HOr.hOr (0xFFFF : BitVec 16) (0xABCD : BitVec 16))
(HAnd.hAnd (0x0000 : BitVec 16) z)
Expand Down
2 changes: 2 additions & 0 deletions test/lean/trivial.expected.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Sail.sail

def foo (y : Unit) : Unit :=
y

Expand Down
2 changes: 2 additions & 0 deletions test/lean/tuples.expected.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Sail.sail

def tuple1 : (Int × Int × (BitVec 2 × Unit)) :=
(3, 5, ((0b10 : BitVec 2), ()))

Expand Down

0 comments on commit f3da818

Please sign in to comment.