Skip to content

Commit

Permalink
Update mono tests and mono_rewrites library
Browse files Browse the repository at this point in the history
- remove old effect annotations
- add pure annotations for relevant builtins
- replace old set type notation
- add non-negative bitvector size constraints where necessary
  • Loading branch information
bacam committed Dec 10, 2024
1 parent bd6fb0b commit c050b3d
Show file tree
Hide file tree
Showing 15 changed files with 64 additions and 62 deletions.
58 changes: 29 additions & 29 deletions lib/mono_rewrites.sail
Original file line number Diff line number Diff line change
Expand Up @@ -54,58 +54,58 @@ $include <vector_dec.sail>

/* External definitions not in the usual asl prelude */

val extzv = {lem: "extz_vec"} : forall 'n 'm, 'm >= 0. (implicit('m), bitvector('n, dec)) -> bitvector('m, dec) effect pure
val extzv = pure {lem: "extz_vec"} : forall 'n 'm, 'm >= 0. (implicit('m), bitvector('n, dec)) -> bitvector('m, dec)
function extzv(m, v) = {
if m < 'n then truncate(v, m) else sail_zero_extend(v, m)
}

val extsv = {lem: "exts_vec"} : forall 'n 'm, 'm >= 0. (implicit('m), bitvector('n, dec)) -> bitvector('m, dec) effect pure
val extsv = pure {lem: "exts_vec"} : forall 'n 'm, 'm >= 0. (implicit('m), bitvector('n, dec)) -> bitvector('m, dec)
function extsv(m, v) = {
if m < 'n then truncate(v, m) else sail_sign_extend(v, m)
}

/* This is generated internally to deal with case splits which reveal the size
of a bitvector */
val bitvector_cast_in = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect pure
val bitvector_cast_out = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect pure
val bitvector_cast_in = pure "zeroExtend" : forall 'n. bits('n) -> bits('n)
val bitvector_cast_out = pure "zeroExtend" : forall 'n. bits('n) -> bits('n)

/* Builtins for the rewrites */
val string_of_bits_subrange = pure "string_of_bits_subrange" : forall 'n. (bits('n), int, int) -> string

/* Definitions for the rewrites */

val is_zero_subrange : forall 'n, 'n >= 0.
(bits('n), int, int) -> bool effect pure
(bits('n), int, int) -> bool

function is_zero_subrange (xs, i, j) = {
(xs & slice_mask(j, i-j+1)) == extzv([bitzero] : bits(1))
}

val is_zeros_slice : forall 'n, 'n >= 0.
(bits('n), int, int) -> bool effect pure
(bits('n), int, int) -> bool

function is_zeros_slice (xs, i, l) = {
(xs & slice_mask(i, l)) == extzv([bitzero] : bits(1))
}

val is_ones_subrange : forall 'n, 'n >= 0.
(bits('n), int, int) -> bool effect pure
(bits('n), int, int) -> bool

function is_ones_subrange (xs, i, j) = {
let m : bits('n) = slice_mask(j,i-j+1) in
(xs & m) == m
}

val is_ones_slice : forall 'n, 'n >= 0.
(bits('n), int, int) -> bool effect pure
(bits('n), int, int) -> bool

function is_ones_slice (xs, i, j) = {
let m : bits('n) = slice_mask(i,j) in
(xs & m) == m
}

val slice_slice_concat : forall 'n 'm 'r, 'n >= 0 & 'm >= 0 & 'r >= 0.
(implicit('r), bits('n), int, int, bits('m), int, int) -> bits('r) effect pure
(implicit('r), bits('n), int, int, bits('m), int, int) -> bits('r)

function slice_slice_concat (r, xs, i, l, ys, i', l') = {
let xs = sail_shiftright(xs & slice_mask(i,l), i) in
Expand All @@ -114,23 +114,23 @@ function slice_slice_concat (r, xs, i, l, ys, i', l') = {
}

val slice_zeros_concat : forall 'n 'p 'q, 'n >= 0 & 'p + 'q >= 0.
(bits('n), int, atom('p), atom('q)) -> bits('p + 'q) effect pure
(bits('n), int, atom('p), atom('q)) -> bits('p + 'q)

function slice_zeros_concat (xs, i, l, l') = {
let xs = sail_shiftright(xs & slice_mask(i,l), i) in
sail_shiftleft(extzv(l + l', xs), l')
}

val subrange_zeros_concat : forall 'n 'hi 'lo 'q, 'n >= 0 & 'hi - 'lo + 1 + 'q >= 0.
(bits('n), atom('hi), atom('lo), atom('q)) -> bits('hi - 'lo + 1 + 'q) effect pure
(bits('n), atom('hi), atom('lo), atom('q)) -> bits('hi - 'lo + 1 + 'q)

function subrange_zeros_concat (xs, hi, lo, l') =
slice_zeros_concat(xs, lo, hi - lo + 1, l')

/* Assumes initial vectors are of equal size */

val subrange_subrange_eq : forall 'n, 'n >= 0.
(bits('n), int, int, bits('n), int, int) -> bool effect pure
(bits('n), int, int, bits('n), int, int) -> bool

function subrange_subrange_eq (xs, i, j, ys, i', j') = {
let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in
Expand All @@ -139,7 +139,7 @@ function subrange_subrange_eq (xs, i, j, ys, i', j') = {
}

val subrange_subrange_concat : forall 'n 'o 'p 'm 'q 'r 's, 's >= 0 & 'n >= 0 & 'm >= 0.
(implicit('s), bits('n), atom('o), atom('p), bits('m), atom('q), atom('r)) -> bits('s) effect pure
(implicit('s), bits('n), atom('o), atom('p), bits('m), atom('q), atom('r)) -> bits('s)

function subrange_subrange_concat (s, xs, i, j, ys, i', j') = {
let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in
Expand All @@ -148,70 +148,70 @@ function subrange_subrange_concat (s, xs, i, j, ys, i', j') = {
}

val place_subrange : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int, int) -> bits('m) effect pure
(implicit('m), bits('n), int, int, int) -> bits('m)

function place_subrange(m,xs,i,j,shift) = {
let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in
sail_shiftleft(extzv(m, xs), shift)
}

val place_slice : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int, int) -> bits('m) effect pure
(implicit('m), bits('n), int, int, int) -> bits('m)

function place_slice(m,xs,i,l,shift) = {
let xs = sail_shiftright(xs & slice_mask(i,l), i) in
sail_shiftleft(extzv(m, xs), shift)
}

val set_slice_zeros : forall 'n, 'n >= 0.
(implicit('n), bits('n), int, int) -> bits('n) effect pure
(implicit('n), bits('n), int, int) -> bits('n)

function set_slice_zeros(n, xs, i, l) = {
let ys : bits('n) = slice_mask(n, i, l) in
xs & not_vec(ys)
}

val set_subrange_zeros : forall 'n, 'n >= 0.
(implicit('n), bits('n), int, int) -> bits('n) effect pure
(implicit('n), bits('n), int, int) -> bits('n)

function set_subrange_zeros(n, xs, hi, lo) =
set_slice_zeros(n, xs, lo, hi - lo + 1)

val zext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int) -> bits('m) effect pure
(implicit('m), bits('n), int, int) -> bits('m)

function zext_slice(m,xs,i,l) = {
let xs = sail_shiftright(xs & slice_mask(i,l), i) in
extzv(m, xs)
}

val zext_subrange : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int) -> bits('m) effect pure
(implicit('m), bits('n), int, int) -> bits('m)

function zext_subrange(m, xs, i, j) = zext_slice(m, xs, j, i - j + 1)

val sext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int) -> bits('m) effect pure
(implicit('m), bits('n), int, int) -> bits('m)

function sext_slice(m,xs,i,l) = {
let xs = sail_arith_shiftright(sail_shiftleft((xs & slice_mask(i,l)), ('n - i - l)), 'n - l) in
extsv(m, xs)
}

val sext_subrange : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int) -> bits('m) effect pure
(implicit('m), bits('n), int, int) -> bits('m)

function sext_subrange(m, xs, i, j) = sext_slice(m, xs, j, i - j + 1)

val place_slice_signed : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int, int) -> bits('m) effect pure
(implicit('m), bits('n), int, int, int) -> bits('m)

function place_slice_signed(m,xs,i,l,shift) = {
sail_shiftleft(sext_slice(m, xs, i, l), shift)
}

val place_subrange_signed : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int, int) -> bits('m) effect pure
(implicit('m), bits('n), int, int, int) -> bits('m)

function place_subrange_signed(m,xs,i,j,shift) = {
place_slice_signed(m, xs, j, i-j+1, shift)
Expand All @@ -220,7 +220,7 @@ function place_subrange_signed(m,xs,i,j,shift) = {
/* This has different names in the aarch64 prelude (UInt) and the other
preludes (unsigned). To avoid variable name clashes, we redeclare it
here with a suitably awkward name. */
val _builtin_unsigned = {
val _builtin_unsigned = pure {
ocaml: "uint",
lem: "uint",
interpreter: "uint",
Expand All @@ -232,7 +232,7 @@ val _builtin_unsigned = {
they agree on positive values. We use this here to give more precise return
types for unsigned_slice and unsigned_subrange. */

val _builtin_mod_nat = {
val _builtin_mod_nat = pure {
smt: "mod",
ocaml: "modulus",
lem: "integerMod",
Expand All @@ -242,26 +242,26 @@ val _builtin_mod_nat = {

/* Below we need the fact that 2 ^ 'n >= 0, so we axiomatise it in the return
type of pow2, as SMT solvers tend to have problems with exponentiation. */
val _builtin_pow2 = "pow2" : forall 'n, 'n >= 0. int('n) -> {'m, 'm == 2 ^ 'n & 'm >= 0. int('m)}
val _builtin_pow2 = pure "pow2" : forall 'n, 'n >= 0. int('n) -> {'m, 'm == 2 ^ 'n & 'm >= 0. int('m)}

val unsigned_slice : forall 'n 'l, 'n >= 0 & 'l >= 0.
(bits('n), int, int('l)) -> {'m, 0 <= 'm < 2 ^ 'l. int('m)} effect pure
(bits('n), int, int('l)) -> {'m, 0 <= 'm < 2 ^ 'l. int('m)}

function unsigned_slice(xs,i,l) = {
let xs = sail_shiftright(xs & slice_mask(i,l), i) in
_builtin_mod_nat(_builtin_unsigned(xs), _builtin_pow2(l))
}

val unsigned_subrange : forall 'n 'i 'j, 'n >= 0 & ('i - 'j) >= 0.
(bits('n), int('i), int('j)) -> {'m, 0 <= 'm < 2 ^ ('i - 'j + 1). int('m)} effect pure
(bits('n), int('i), int('j)) -> {'m, 0 <= 'm < 2 ^ ('i - 'j + 1). int('m)}

function unsigned_subrange(xs,i,j) = {
let xs = sail_shiftright(xs & slice_mask(j,i-j+1), i) in
_builtin_mod_nat(_builtin_unsigned(xs), _builtin_pow2(i - j + 1))
}


val zext_ones : forall 'n, 'n >= 0. (implicit('n), int) -> bits('n) effect pure
val zext_ones : forall 'n, 'n >= 0. (implicit('n), int) -> bits('n)

function zext_ones(n, m) = {
let v : bits('n) = extsv([bitone] : bits(1)) in
Expand Down
4 changes: 2 additions & 2 deletions test/mono/builtins.sail
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ $include <arith.sail>
$include <flow.sail>
$include <vector_dec.sail>

val neq_vec = {lem: "neq_vec"} : forall 'n. (bits('n), bits('n)) -> bool
val neq_vec = pure {lem: "neq_vec"} : forall 'n. (bits('n), bits('n)) -> bool
function neq_vec (x, y) = not_bool(x == y)
overload operator != = {neq_vec}
val UInt = {
val UInt = pure {
ocaml: "uint",
lem: "uint",
interpreter: "uint",
Expand Down
18 changes: 9 additions & 9 deletions test/mono/castreq.sail
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ default Order dec
$include <prelude.sail>
$include <smt.sail>

val extzv : forall 'n 'm, 'm >= 0 & 'n >= 0. (implicit('m), bitvector('n, dec)) -> bitvector('m, dec) effect pure
val extzv : forall 'n 'm, 'm >= 0 & 'n >= 0. (implicit('m), bitvector('n, dec)) -> bitvector('m, dec)
function extzv(m, v) = sail_mask(m, v)


/* Test generation of casts across case splits (e.g., going from bits('m) to bits(32)) */

val foo : forall 'm 'n, 'm in {8,16} & 'n in {32,64}. (implicit('n), bits('m)) -> bits('n) effect pure
val foo : forall 'm 'n, 'm in {8,16} & 'n in {32,64}. (implicit('n), bits('m)) -> bits('n)

function foo(n, x) =
let y : bits(16) = extzv(x) in
Expand All @@ -17,27 +17,27 @@ function foo(n, x) =
64 => let z = y@y@y@y in let dfsf = 4 in z
}

val foo_if : forall 'm 'n, 'm in {8,16} & 'n in {32,64}. (implicit('n), bits('m)) -> bits('n) effect pure
val foo_if : forall 'm 'n, 'm in {8,16} & 'n in {32,64}. (implicit('n), bits('m)) -> bits('n)

function foo_if(n, x) =
let y : bits(16) = extzv(x) in
if n == 32
then y@y
else /* 64 */ let z = y@y@y@y in let dfsf = 4 in z

val use : bits(16) -> unit effect pure
val use : bits(16) -> unit

function use(x) = ()

val bar : forall 'n, 'n in {8,16}. bits('n) -> unit effect pure
val bar : forall 'n, 'n in {8,16}. bits('n) -> unit

function bar(x) =
match 'n {
8 => use(x@x),
16 => use(x)
}

val bar_if : forall 'n, 'n in {8,16}. bits('n) -> unit effect pure
val bar_if : forall 'n, 'n in {8,16}. bits('n) -> unit

function bar_if(x) =
if 'n == 8
Expand Down Expand Up @@ -146,7 +146,7 @@ function refine_mutable_exp2(n, x) = {

/* Adding casts for top-level pattern matches */

val foo2 : forall 'm 'n, 'm in {8,16} & 'n in {32,64}. (int('n), bits('m)) -> bits('n) effect pure
val foo2 : forall 'm 'n, 'm in {8,16} & 'n in {32,64}. (int('n), bits('m)) -> bits('n)

function foo2(32,x) =
let y : bits(16) = extzv(x) in
Expand All @@ -155,7 +155,7 @@ and foo2(64,x) =
let y : bits(16) = extzv(x) in
let z = y@y@y@y in let dfsf = 4 in z

val foo3 : forall 'm 'n, 'm >= 0 & 'n in {4,8}. (int('n), bits('m)) -> bits(8 * 'n) effect pure
val foo3 : forall 'm 'n, 'm >= 0 & 'n in {4,8}. (int('n), bits('m)) -> bits(8 * 'n)

function foo3(4,x) =
let y : bits(16) = extzv(x) in
Expand All @@ -165,7 +165,7 @@ and foo3(8,x) =
let z = y@y@y@y in let dfsf = 4 in z

/* Casting an argument isn't supported yet
val bar2 : forall 'n, 'n in {8,16}. (int('n),bits('n)) -> unit effect pure
val bar2 : forall 'n, 'n in {8,16}. (int('n),bits('n)) -> unit

function bar2(8,x) =
use(x@x)
Expand Down
2 changes: 1 addition & 1 deletion test/mono/castrequnion.sail
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ default Order dec
$include <prelude.sail>
$include <mono_rewrites.sail>

val foo : forall 'm 'n, 'm in {8,16} & 'n in {16,32,64}. (implicit('n), bits('m)) -> option(bits('n)) effect pure
val foo : forall 'm 'n, 'm in {8,16} & 'n in {16,32,64}. (implicit('n), bits('m)) -> option(bits('n))

function foo(n, x) =
let y : bits(16) = sail_zero_extend(x,16) in
Expand Down
2 changes: 1 addition & 1 deletion test/mono/exint.sail
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
default Order dec
$include <prelude.sail>

val cast ex_int : int -> {'n, true. int('n)}
val ex_int : int -> {'n, true. int('n)}
function ex_int 'n = n


Expand Down
13 changes: 7 additions & 6 deletions test/mono/itself_rewriting.sail
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ $include <prelude.sail>
added in the right places, but it's also worth running in case that gets
broken. */

val needs_size_in_guard : forall 'n. int('n) -> unit
val needs_size_in_guard : forall 'n, 'n >= 0. int('n) -> unit

function needs_size_in_guard(n if n > 8) = {
let x : bits('n) = replicate_bits(0b0,n);
Expand All @@ -16,7 +16,7 @@ and needs_size_in_guard(n) = {
()
}

val no_size_in_guard : forall 'n. (int('n), int) -> unit
val no_size_in_guard : forall 'n, 'n >= 0. (int('n), int) -> unit

function no_size_in_guard((n,m) if m > 8) = {
let x : bits('n) = replicate_bits(0b0,n);
Expand All @@ -27,7 +27,7 @@ and no_size_in_guard(n,m) = {
()
}

val shadowed : forall 'n. int('n) -> unit
val shadowed : forall 'n, 'n >= 0. int('n) -> unit

function shadowed(n) = {
let n = 8;
Expand All @@ -38,13 +38,13 @@ function shadowed(n) = {
val willsplit : bool -> unit

function willsplit(x) = {
let 'n : int = if x then 8 else 16;
let 'n : nat = if x then 8 else 16;
needs_size_in_guard(n);
no_size_in_guard(n,n);
shadowed(n);
}

val execute : forall 'datasize. int('datasize) -> unit
val execute : forall 'datasize, 'datasize >= 0. int('datasize) -> unit

function execute(datasize) = {
let x : bits('datasize) = replicate_bits(0b1, datasize);
Expand All @@ -56,10 +56,11 @@ val test_execute : unit -> unit
function test_execute() = {
let exp = 4;
let 'datasize = shl_int(1, exp);
assert('datasize >= 0);
execute(datasize)
}

val transitive_itself : forall 'n. int('n) -> unit
val transitive_itself : forall 'n, 'n >= 0. int('n) -> unit

function transitive_itself(n) = {
needs_size_in_guard(n);
Expand Down
Loading

0 comments on commit c050b3d

Please sign in to comment.