Skip to content

Commit

Permalink
Support undefined function generation for scattered unions
Browse files Browse the repository at this point in the history
Also adjust interpreter to collect all of the functions before attempting
to initialise the registers, so that it can find generated functions
placed at the end.

Fixes #216
  • Loading branch information
bacam committed May 5, 2023
1 parent 6299db6 commit fcdcdc5
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 58 deletions.
7 changes: 7 additions & 0 deletions src/lib/ast_util.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,13 @@ let record_ids defs =
in
IdSet.of_list (rec_ids defs)

let rec get_scattered_union_clauses id = function
| DEF_aux (DEF_scattered (SD_aux (SD_unioncl (uid, tu), _)), _) :: defs when Id.compare id uid = 0 ->
tu :: get_scattered_union_clauses id defs
| _ :: defs ->
get_scattered_union_clauses id defs
| [] -> []

let order_compare (Ord_aux (o1,_)) (Ord_aux (o2,_)) =
match o1, o2 with
| Ord_var k1, Ord_var k2 -> Kid.compare k1 k2
Expand Down
4 changes: 3 additions & 1 deletion src/lib/ast_util.mli
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,9 @@ val ids_of_ast : 'a ast -> IdSet.t
val val_spec_ids : 'a def list -> IdSet.t

val record_ids : 'a def list -> IdSet.t


val get_scattered_union_clauses : id -> 'a def list -> type_union list

val pat_ids : 'a pat -> IdSet.t

val subst : id -> 'a exp -> 'a exp -> 'a exp
Expand Down
106 changes: 59 additions & 47 deletions src/lib/initial_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,55 @@ let generate_undefineds vs_ids defs =
| [pat] -> pat
| pats -> mk_pat (P_tuple pats)
in
let undefined_union id typq tus =
let pat = p_tup (quant_items typq |> List.map quant_item_param |> List.concat |> List.map (fun id -> mk_pat (P_id id))) in
let body =
if !opt_fast_undefined && List.length tus > 0 then
undefined_tu (List.hd tus)
else
(* Deduplicate arguments for each constructor to keep definitions
manageable. *)
let extract_tu = function
| Tu_aux (Tu_ty_id (Typ_aux (Typ_tuple typs, _), id), _) -> (id, typs)
| Tu_aux (Tu_ty_id (typ, id), _) -> (id, [typ])
in
let record_arg_typs m (_,typs) =
let m' =
List.fold_left (fun m typ ->
TypMap.add typ (1 + try TypMap.find typ m with Not_found -> 0) m) TypMap.empty typs in
TypMap.merge (fun _ x y -> match x,y with Some m, Some n -> Some (max m n)
| None, x -> x
| x, None -> x) m m'
in
let make_undef_var typ n (i,lbs,m) =
let j = i+n in
let rec aux k =
if k = j then [] else
let v = mk_id ("u_" ^ string_of_int k) in
(mk_letbind (mk_pat (P_typ (typ,mk_pat (P_id v)))) (mk_lit_exp L_undef))::
(aux (k+1))
in
(j, aux i @ lbs, TypMap.add typ i m)
in
let make_constr m (id,typs) =
let args, _ = List.fold_right (fun typ (acc,m) ->
let i = TypMap.find typ m in
(mk_exp (E_id (mk_id ("u_" ^ string_of_int i)))::acc,
TypMap.add typ (i+1) m)) typs ([],m) in
mk_exp (E_app (id, args))
in
let constr_args = List.map extract_tu tus in
let typs_needed = List.fold_left record_arg_typs TypMap.empty constr_args in
let (_,letbinds,typ_to_var) = TypMap.fold make_undef_var typs_needed (0,[],TypMap.empty) in
List.fold_left (fun e lb -> mk_exp (E_let (lb,e)))
(mk_exp (E_app (mk_id "internal_pick",
[mk_exp (E_list (List.map (make_constr typ_to_var) constr_args))]))) letbinds
in
(mk_val_spec (VS_val_spec (undefined_typschm id typq, prepend_id "undefined_" id, None, false)),
mk_fundef [mk_funcl (prepend_id "undefined_" id)
pat
body])
in
let undefined_td = function
| TD_enum (id, ids, _) when not (IdSet.mem (prepend_id "undefined_" id) vs_ids) ->
let typschm = typschm_of_string ("unit -> " ^ string_of_id id) in
Expand All @@ -1136,58 +1185,21 @@ let generate_undefineds vs_ids defs =
pat
(mk_exp (E_struct (List.map (fun (_, id) -> mk_fexp id (mk_lit_exp L_undef)) fields)))]]
| TD_variant (id, typq, tus, _) when not (IdSet.mem (prepend_id "undefined_" id) vs_ids) ->
let pat = p_tup (quant_items typq |> List.map quant_item_param |> List.concat |> List.map (fun id -> mk_pat (P_id id))) in
let body =
if !opt_fast_undefined && List.length tus > 0 then
undefined_tu (List.hd tus)
else
(* Deduplicate arguments for each constructor to keep definitions
manageable. *)
let extract_tu = function
| Tu_aux (Tu_ty_id (Typ_aux (Typ_tuple typs, _), id), _) -> (id, typs)
| Tu_aux (Tu_ty_id (typ, id), _) -> (id, [typ])
in
let record_arg_typs m (_,typs) =
let m' =
List.fold_left (fun m typ ->
TypMap.add typ (1 + try TypMap.find typ m with Not_found -> 0) m) TypMap.empty typs in
TypMap.merge (fun _ x y -> match x,y with Some m, Some n -> Some (max m n)
| None, x -> x
| x, None -> x) m m'
in
let make_undef_var typ n (i,lbs,m) =
let j = i+n in
let rec aux k =
if k = j then [] else
let v = mk_id ("u_" ^ string_of_int k) in
(mk_letbind (mk_pat (P_typ (typ,mk_pat (P_id v)))) (mk_lit_exp L_undef))::
(aux (k+1))
in
(j, aux i @ lbs, TypMap.add typ i m)
in
let make_constr m (id,typs) =
let args, _ = List.fold_right (fun typ (acc,m) ->
let i = TypMap.find typ m in
(mk_exp (E_id (mk_id ("u_" ^ string_of_int i)))::acc,
TypMap.add typ (i+1) m)) typs ([],m) in
mk_exp (E_app (id, args))
in
let constr_args = List.map extract_tu tus in
let typs_needed = List.fold_left record_arg_typs TypMap.empty constr_args in
let (_,letbinds,typ_to_var) = TypMap.fold make_undef_var typs_needed (0,[],TypMap.empty) in
List.fold_left (fun e lb -> mk_exp (E_let (lb,e)))
(mk_exp (E_app (mk_id "internal_pick",
[mk_exp (E_list (List.map (make_constr typ_to_var) constr_args))]))) letbinds
in
[mk_val_spec (VS_val_spec (undefined_typschm id typq, prepend_id "undefined_" id, None, false));
mk_fundef [mk_funcl (prepend_id "undefined_" id)
pat
body]]
let vs, def = undefined_union id typq tus in
[vs; def]
| _ -> []
in
let undefined_scattered id typq =
let tus = get_scattered_union_clauses id defs in
undefined_union id typq tus
in
let rec undefined_defs = function
| DEF_aux (DEF_type (TD_aux (td_aux, _)), _) as def :: defs ->
def :: undefined_td td_aux @ undefined_defs defs
(* The function definition must come after the scattered type definition is complete, so put it at the end. *)
| DEF_aux (DEF_scattered (SD_aux (SD_variant (id, typq), _)), _) as def :: defs ->
let vs, fn = undefined_scattered id typq in
def :: vs :: undefined_defs defs @ [fn]
| def :: defs ->
def :: undefined_defs defs
| [] -> []
Expand Down
8 changes: 6 additions & 2 deletions src/lib/interpreter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1019,8 +1019,6 @@ let rec initialize_registers allow_registers gstate =
| Some exp ->
{ gstate with registers = Bindings.add id (eval_exp (initial_lstate, gstate) exp) gstate.registers }
end
| DEF_aux (DEF_fundef fdef, _) ->
{ gstate with fundefs = Bindings.add (id_of_fundef fdef) fdef gstate.fundefs }
| _ -> gstate
in
function
Expand All @@ -1030,6 +1028,12 @@ let rec initialize_registers allow_registers gstate =

let initial_state ?(registers=true) ast env primops =
let gstate = initial_gstate primops ast.defs env in
let add_function gstate = function
| DEF_aux (DEF_fundef fdef, _) ->
{ gstate with fundefs = Bindings.add (id_of_fundef fdef) fdef gstate.fundefs }
| _ -> gstate
in
let gstate = List.fold_left add_function gstate ast.defs in
let gstate =
{ (initialize_registers registers gstate ast.defs)
with allow_registers = registers }
Expand Down
9 changes: 1 addition & 8 deletions src/lib/scattered.ml
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,6 @@ let fake_rec_opt l = Rec_aux (Rec_nonrec, gen_loc l)

let no_tannot_opt l = Typ_annot_opt_aux (Typ_annot_opt_none, gen_loc l)

let rec get_union_clauses id = function
| DEF_aux (DEF_scattered (SD_aux (SD_unioncl (uid, tu), _)), _) :: defs when Id.compare id uid = 0 ->
tu :: get_union_clauses id defs
| _ :: defs ->
get_union_clauses id defs
| [] -> []

let rec filter_union_clauses id = function
| DEF_aux (DEF_scattered (SD_aux (SD_unioncl (uid, tu), _)), _) :: defs when Id.compare id uid = 0 ->
filter_union_clauses id defs
Expand Down Expand Up @@ -165,7 +158,7 @@ let rec descatter' funcls mapcls = function
immediately grab all the future clauses and turn it into a
regular union declaration. *)
| DEF_aux (DEF_scattered (SD_aux (SD_variant (id, typq), (l, _))), def_annot) :: defs ->
let tus = get_union_clauses id defs in
let tus = get_scattered_union_clauses id defs in
begin match tus with
| [] -> raise (Reporting.err_general l "No clauses found for scattered union type")
| _ ->
Expand Down

0 comments on commit fcdcdc5

Please sign in to comment.