Skip to content

Commit

Permalink
Allow generating a vector for enumerations containing all members (#822)
Browse files Browse the repository at this point in the history
* TC: Check that scattered definitions aren't ended twice

* TC: Allow generating a vector containing enum members
  • Loading branch information
Alasdair authored Dec 10, 2024
1 parent 58b7d87 commit 25c76c8
Show file tree
Hide file tree
Showing 13 changed files with 221 additions and 73 deletions.
5 changes: 5 additions & 0 deletions doc/asciidoc/language.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,11 @@ sail::ENUM_CONV[from=enum,type=span]
The `no_enum_number_conversions` attribute can be used to disable the
generation of these functions entirely.

We also support generating a vector containing all elements of the
enumeration, using the `enum_vector` attribute.

sail::ENUM_VECTOR[from=enum,type=span]

==== Unions
:union: sail_doc/union.json

Expand Down
17 changes: 17 additions & 0 deletions doc/examples/enum.sail
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,20 @@ function custom_conversions() -> unit = {
assert(from_number(1) == Member2);
}
$span end

val iterate : unit -> unit

$span start ENUM_VECTOR
$[enum_vector my_enum_members]
enum My_other_enum = { M1, M2, M3 }

function iterate() = {
foreach (i from 0 to (length(my_enum_members) - 1)) {
match my_enum_members[i] {
M1 => print_endline("1"),
M2 => print_endline("2"),
M3 => print_endline("3"),
}
}
}
$span end
69 changes: 48 additions & 21 deletions doc/manual.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/lib/ast_util.ml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ let get_attribute attr annot =
let get_attributes annot = annot.attrs

let find_attribute_opt attr1 attrs =
List.find_opt (fun (_, attr2, _) -> attr1 = attr2) attrs |> Option.map (fun (_, _, arg) -> arg)
List.find_opt (fun (_, attr2, _) -> attr1 = attr2) attrs |> Option.map (fun (l, _, arg) -> (l, arg))

let mk_def_annot ?doc ?(attrs = []) ?(visibility = Public) l env =
{ doc_comment = doc; attrs; visibility; loc = l; env }
Expand Down
4 changes: 3 additions & 1 deletion src/lib/ast_util.mli
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ val get_attribute : string -> uannot -> (l * attribute_data option) option

val get_attributes : uannot -> (l * string * attribute_data option) list

val find_attribute_opt : string -> (l * string * attribute_data option) list -> attribute_data option option
val find_attribute_opt : string -> (l * string * attribute_data option) list -> (l * attribute_data option) option

val mk_def_annot :
?doc:string -> ?attrs:(l * string * attribute_data option) list -> ?visibility:visibility -> l -> 'a -> 'a def_annot
Expand Down Expand Up @@ -594,6 +594,8 @@ val locate_lexp : (l -> l) -> 'a lexp -> 'a lexp

val locate_typ : (l -> l) -> typ -> typ

val locate_letbind : (l -> l) -> 'a letbind -> 'a letbind

(** Make a unique location by giving it a Parse_ast.Unique wrapper with
a generated number. *)
val unique : l -> l
Expand Down
43 changes: 31 additions & 12 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4874,20 +4874,40 @@ let rec check_typedef : Env.t -> env def_annot -> uannot type_def -> typed_def l
([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], env)
| TD_enum (id, ids, _) ->
let env = Env.add_enum id ids env in
(* If the enumeration has the "enum_vector" attribute, we will generate a
top-level letbinding which is a vector of all the members. *)
let def_annot, enum_vector, env =
match get_def_attribute "enum_vector" def_annot with
| Some (l, Some (AD_aux (AD_string enum_vector_name, _))) ->
let enum_vector_id = mk_id enum_vector_name in
let typ = vector_typ (nint (List.length ids)) (mk_id_typ id) in
let letbind =
mk_letbind
(mk_pat (P_typ (typ, mk_pat (P_id enum_vector_id))))
(mk_exp (E_vector (List.rev_map (fun member -> mk_exp (E_id member)) ids)))
|> locate_letbind (fun _ -> gen_loc l)
in
let defs, env = check_letdef env (mk_def_annot (gen_loc l) env) letbind in
(remove_def_attribute "enum_vector" def_annot, defs, env)
| Some (l, _) -> raise (Reporting.err_general l "Invalid enum_vector attribute")
| None -> (def_annot, [], env)
in
begin
match get_def_attribute "undefined_gen" def_annot with
| Some (_, Some (AD_aux (AD_string "forbid", _))) ->
([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], env)
([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)] @ enum_vector, env)
| Some (_, Some (AD_aux (AD_string "skip", _))) ->
([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], Env.allow_user_undefined id env)
( [DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)] @ enum_vector,
Env.allow_user_undefined id env
)
| Some (_, Some (AD_aux (AD_string "generate", _))) | None ->
let undefined_defs = Initial_check.generate_undefined_enum id ids in
let undefined_defs, env = check_defs env undefined_defs in
let def_annot =
def_annot |> remove_def_attribute "undefined_gen"
|> add_def_attribute (gen_loc l) "undefined_gen" (undefined_skip l)
in
( DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot) :: undefined_defs,
( (DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot) :: undefined_defs) @ enum_vector,
Env.allow_user_undefined id env
)
| Some (attr_l, Some arg) ->
Expand Down Expand Up @@ -4950,18 +4970,17 @@ and check_scattered : Env.t -> env def_annot -> uannot scattered_def -> typed_de
match sdef with
| SD_function (id, tannot_opt) ->
( [DEF_aux (DEF_scattered (SD_aux (SD_function (id, tannot_opt), (l, empty_tannot))), def_annot)],
Env.add_scattered_id id env
Env.add_scattered_id id def_annot.attrs env
)
| SD_mapping (id, tannot_opt) ->
( [DEF_aux (DEF_scattered (SD_aux (SD_mapping (id, tannot_opt), (l, empty_tannot))), def_annot)],
Env.add_scattered_id id env
Env.add_scattered_id id def_annot.attrs env
)
| SD_end id ->
if not (Env.is_scattered_id id env) then
typ_error l (string_of_id id ^ " is not a scattered definition, so it cannot be ended")
else ([], env)
| SD_end id -> ([], Env.end_scattered_id ~at:l id env)
| SD_enum id ->
([DEF_aux (DEF_scattered (SD_aux (SD_enum id, (l, empty_tannot))), def_annot)], Env.add_scattered_enum id env)
( [DEF_aux (DEF_scattered (SD_aux (SD_enum id, (l, empty_tannot))), def_annot)],
Env.add_scattered_enum id def_annot.attrs env
)
| SD_enumcl (id, member) ->
( [DEF_aux (DEF_scattered (SD_aux (SD_enumcl (id, member), (l, empty_tannot))), def_annot)],
Env.add_enum_clause id member env
Expand Down Expand Up @@ -5001,14 +5020,14 @@ and check_scattered : Env.t -> env def_annot -> uannot scattered_def -> typed_de
let funcl_env = Env.add_typquant fcl_def_annot.loc typq env in
let funcl = check_funcl funcl_env funcl typ in
( [DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, (l, mk_tannot ~uannot funcl_env typ))), def_annot)],
Env.add_scattered_id id env
Env.add_scattered_id id def_annot.attrs env
)
| SD_mapcl (id, mapcl) ->
let typq, typ = Env.get_val_spec id env in
let mapcl_env = Env.add_typquant l typq env in
let mapcl = check_mapcl mapcl_env mapcl typ in
( [DEF_aux (DEF_scattered (SD_aux (SD_mapcl (id, mapcl), (l, empty_tannot))), def_annot)],
Env.add_scattered_id id env
Env.add_scattered_id id def_annot.attrs env
)

and check_outcome : Env.t -> outcome_spec -> untyped_def list -> outcome_spec * typed_def list * Env.t =
Expand Down
94 changes: 60 additions & 34 deletions src/lib/type_env.ml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ type global_env = {
registers : typ env_item Bindings.t;
overloads : id list multiple_env_item Bindings.t;
outcomes : (typquant * typ * kinded_id list * id list * (typquant * typ) env_item Bindings.t) env_item Bindings.t;
scattered_ids : IdSet.t;
scattered_ids : ((l * string * Ast.attribute_data option) list, Ast.l) result Bindings.t;
outcome_instantiation : (Ast.l * typ) KBindings.t;
}

Expand All @@ -126,7 +126,7 @@ let empty_global_env =
registers = Bindings.empty;
overloads = Bindings.empty;
outcomes = Bindings.empty;
scattered_ids = IdSet.empty;
scattered_ids = Bindings.empty;
outcome_instantiation = KBindings.empty;
}

Expand Down Expand Up @@ -1022,6 +1022,18 @@ let wf_constraint ~at:at_l env (NC_aux (_, l) as nc) =
let extra, l = match l with Parse_ast.Unknown -> (" here", at_l) | _ -> ("", l) in
typ_raise l (err_because (Err_other ("Well-formedness check failed for constraint" ^ extra), err_l, err))
let string_of_mtyp (mut, typ) =
match mut with Immutable -> string_of_typ typ | Mutable -> "ref<" ^ string_of_typ typ ^ ">"
let add_local id mtyp env =
if not env.allow_bindings then typ_error (id_loc id) "Bindings are not allowed in this context";
wf_typ ~at:(id_loc id) env (snd mtyp);
if Bindings.mem id env.global.val_specs then
typ_error (id_loc id) ("Local variable " ^ string_of_id id ^ " is already bound as a function name")
else ();
typ_print (lazy (adding ^ "local binding " ^ string_of_id id ^ " : " ^ string_of_mtyp mtyp)) [@coverage off];
{ env with locals = Bindings.add id mtyp env.locals }
let add_typquant l quant env =
let rec add_quant_item env = function QI_aux (qi, _) -> add_quant_item_aux env qi
and add_quant_item_aux env = function
Expand Down Expand Up @@ -1312,17 +1324,55 @@ let add_enum' is_scattered id ids env =
env
)
let add_scattered_id id env =
update_global (fun global -> { global with scattered_ids = IdSet.add id global.scattered_ids }) env
let add_enum id ids env = add_enum' false id ids env
let is_scattered_id id env = IdSet.mem id env.global.scattered_ids
let get_enum_opt id env =
match Option.map (get_item (id_loc id) env) (Bindings.find_opt id env.global.enums) with
| Some (_, enum) -> Some (IdSet.elements enum)
| None -> None
let add_scattered_enum id env = env |> add_scattered_id id |> add_enum' true id []
let get_enum id env =
match get_enum_opt id env with
| Some enum -> enum
| None -> typ_error (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist")
let add_enum id ids env = add_enum' false id ids env
let get_enums env = filter_items_with snd env env.global.enums
let add_scattered_id id attrs env =
let updater = function None -> Some (Ok attrs) | previous -> previous in
update_global (fun global -> { global with scattered_ids = Bindings.update id updater global.scattered_ids }) env
let add_scattered_enum id attrs env = env |> add_scattered_id id attrs |> add_enum' true id []
let is_scattered_id id env = Bindings.mem id env.global.scattered_ids
let end_scattered_id ~at:l id env =
let attrs = ref [] in
let updater = function
| None -> typ_error l (string_of_id id ^ " is not a scattered definition, so it cannot be ended")
| Some (Ok attrs') ->
attrs := attrs';
Some (Error l)
| Some (Error prev_l) ->
typ_error
(Hint ("previously ended here", prev_l, l))
("Cannot end scattered definition " ^ string_of_id id ^ " as it has already been ended")
in
let env =
update_global (fun global -> { global with scattered_ids = Bindings.update id updater global.scattered_ids }) env
in
match get_enum_opt id env with
| None -> env
| Some members -> (
match find_attribute_opt "enum_vector" !attrs with
| None -> env
| Some (_, Some (AD_aux (AD_string enum_vector_name, _))) ->
add_local (mk_id enum_vector_name) (Immutable, vector_typ (nint (List.length members)) (mk_id_typ id)) env
| Some (l, _) -> raise (Reporting.err_general l "Invalid enum_vector attribute")
)
let add_enum_clause id member env =
let env = add_scattered_id id env in
let env = add_scattered_id id [] env in
match Bindings.find_opt id env.global.enums with
| Some item ->
if not (item_in_scope env item) then
Expand All @@ -1349,18 +1399,6 @@ let add_enum_clause id member env =
)
| None -> typ_error (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist")
let get_enum_opt id env =
match Option.map (get_item (id_loc id) env) (Bindings.find_opt id env.global.enums) with
| Some (_, enum) -> Some (IdSet.elements enum)
| None -> None
let get_enum id env =
match get_enum_opt id env with
| Some enum -> enum
| None -> typ_error (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist")
let get_enums env = filter_items_with snd env env.global.enums
let is_record id env = Bindings.mem id env.global.records
let get_record id env =
Expand Down Expand Up @@ -1428,18 +1466,6 @@ let is_mutable id env =
let to_bool = function Mutable -> true | Immutable -> false in
match Bindings.find_opt id env.locals with Some (mut, _) -> to_bool mut | None -> false
let string_of_mtyp (mut, typ) =
match mut with Immutable -> string_of_typ typ | Mutable -> "ref<" ^ string_of_typ typ ^ ">"
let add_local id mtyp env =
if not env.allow_bindings then typ_error (id_loc id) "Bindings are not allowed in this context";
wf_typ ~at:(id_loc id) env (snd mtyp);
if Bindings.mem id env.global.val_specs then
typ_error (id_loc id) ("Local variable " ^ string_of_id id ^ " is already bound as a function name")
else ();
typ_print (lazy (adding ^ "local binding " ^ string_of_id id ^ " : " ^ string_of_mtyp mtyp)) [@coverage off];
{ env with locals = Bindings.add id mtyp env.locals }
(* Promote a set of identifiers from local bindings to top-level global letbindings *)
let add_toplevel_lets ids (env : env) =
IdSet.fold
Expand Down Expand Up @@ -1477,7 +1503,7 @@ let add_variant id (typq, constructors) env =
)
let add_scattered_variant id typq env =
let env = add_scattered_id id env in
let env = add_scattered_id id [] env in
if bound_typ_id env id then already_bound "scattered union" id env
else (
typ_print (lazy (adding ^ "scattered variant " ^ string_of_id id)) [@coverage off];
Expand All @@ -1493,7 +1519,7 @@ let add_scattered_variant id typq env =
)
let add_variant_clause id tu env =
let env = add_scattered_id id env in
let env = add_scattered_id id [] env in
match Bindings.find_opt id env.global.unions with
| Some ({ item = typq, tus; _ } as item) ->
update_global
Expand Down
5 changes: 3 additions & 2 deletions src/lib/type_env.mli
Original file line number Diff line number Diff line change
Expand Up @@ -207,16 +207,17 @@ val add_extern : id -> extern -> t -> t
val get_extern : id -> t -> string -> string

val add_enum : id -> id list -> t -> t
val add_scattered_enum : id -> t -> t
val add_scattered_enum : id -> (l * string * Ast.attribute_data option) list -> t -> t
val add_enum_clause : id -> id -> t -> t
val get_enum_opt : id -> t -> id list option
val get_enum : id -> t -> id list
val get_enums : t -> IdSet.t Bindings.t

val lookup_id : id -> t -> typ lvar

val add_scattered_id : id -> t -> t
val add_scattered_id : id -> (l * string * Ast.attribute_data option) list -> t -> t
val is_scattered_id : id -> t -> bool
val end_scattered_id : at:Ast.l -> id -> t -> t

val expand_synonyms : t -> typ -> typ
val expand_nexp_synonyms : t -> nexp -> nexp
Expand Down
4 changes: 2 additions & 2 deletions src/sail_doc_backend/docinfo.ml
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ module Generator (Converter : Markdown.CONVERTER) (Config : CONFIG) = struct
in
match find_attribute_opt "split" attrs with
| None -> None
| Some (Some (AD_aux (AD_string split_id, _))) -> (
| Some (_, Some (AD_aux (AD_string split_id, _))) -> (
let split_id = mk_id split_id in
let env = Type_check.env_of exp in
match Type_check.Env.lookup_id split_id env with
Expand Down Expand Up @@ -572,7 +572,7 @@ module Generator (Converter : Markdown.CONVERTER) (Config : CONFIG) = struct

let docinfo_for_mapcl n (MCL_aux (aux, (def_annot, _)) as clause) =
let source = doc_loc def_annot.loc Type_check.strip_mapcl Reformatter.doc_mapcl clause in
let parse_wavedrom_attr = function Some (AD_aux (AD_string s, _)) -> Some s | Some _ | None -> None in
let parse_wavedrom_attr = function _, Some (AD_aux (AD_string s, _)) -> Some s | _, Some _ | _, None -> None in
let wavedrom_attr = Option.bind (find_attribute_opt "wavedrom" def_annot.attrs) parse_wavedrom_attr in

let left, left_wavedrom, right, right_wavedrom, body =
Expand Down
2 changes: 2 additions & 0 deletions test/c/enum_vector.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
A
B
26 changes: 26 additions & 0 deletions test/c/enum_vector.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
default Order dec

$include <prelude.sail>

$option --unroll-loops

$[enum_vector E_members]
scattered enum E

enum clause E = A

enum clause E = B

end E

val main : unit -> unit

function main() = {
let xs : vector(2, E) = E_members;
foreach (i from 0 to 1) {
match xs[i] {
A => print_endline("A"),
B => print_endline("B"),
}
}
}
9 changes: 9 additions & 0 deletions test/typecheck/fail/enum_vector_scattered_dup.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Type error:
Code generated nearby:
fail/enum_vector_scattered_dup.sail:5.0-24:
5 |$[enum_vector E_members]
 |^----------------------^ Previous definition
fail/enum_vector_scattered_dup.sail:12.4-34:
12 |let E_members : vector(1, E) = [A]
 | ^----------------------------^
 | Duplicate toplevel let binding E_members
14 changes: 14 additions & 0 deletions test/typecheck/fail/enum_vector_scattered_dup.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
default Order dec

$include <prelude.sail>

$[enum_vector E_members]
scattered enum E

enum clause E = A

enum clause E = B

let E_members : vector(1, E) = [A]

end E

0 comments on commit 25c76c8

Please sign in to comment.